diff --git a/README.md b/README.md index 2df4c9ad1e7..6982cad43e7 100644 --- a/README.md +++ b/README.md @@ -168,6 +168,7 @@ Connect from any of these platforms: | **TypeScript** (React, Next.js, Vue, Svelte, Angular, Node.js, Bun, Deno) | [Get started](https://spacetimedb.com/docs/quickstarts/react) | | **Rust** | [Get started](https://spacetimedb.com/docs/quickstarts/rust) | | **C#** (standalone and Unity) | [Get started](https://spacetimedb.com/docs/quickstarts/c-sharp) | +| **Kotlin** (Android, JVM, iOS/Native) | [Get started](https://spacetimedb.com/docs/quickstarts/kotlin) | | **C++** (Unreal Engine) | [Get started](https://spacetimedb.com/docs/quickstarts/c-plus-plus) | ## Running with Docker diff --git a/crates/cli/build.rs b/crates/cli/build.rs index 06b962c6066..4d87b87d7a2 100644 --- a/crates/cli/build.rs +++ b/crates/cli/build.rs @@ -84,21 +84,44 @@ fn generate_template_files() { .push_str("pub fn get_template_files() -> HashMap<&'static str, HashMap<&'static str, &'static str>> {\n"); generated_code.push_str(" let mut templates = HashMap::new();\n\n"); + let mut binary_code = String::new(); + binary_code.push_str("#[allow(unused_mut)]\n"); + binary_code.push_str( + "pub fn get_template_binary_files() -> HashMap<&'static str, HashMap<&'static str, &'static [u8]>> {\n", + ); + binary_code.push_str(" let mut templates = HashMap::new();\n\n"); + for template in &discovered_templates { if let Some(ref server_source) = template.server_source { let server_path = PathBuf::from(server_source); - generate_template_entry(&mut generated_code, &server_path, server_source, &manifest_dir); + generate_template_entry( + &mut generated_code, + &mut binary_code, + &server_path, + server_source, + &manifest_dir, + ); } if let Some(ref client_source) = template.client_source { let client_path = PathBuf::from(client_source); - generate_template_entry(&mut generated_code, &client_path, client_source, &manifest_dir); + generate_template_entry( + &mut generated_code, + &mut 
binary_code, + &client_path, + client_source, + &manifest_dir, + ); } } generated_code.push_str(" templates\n"); generated_code.push_str("}\n\n"); + binary_code.push_str(" templates\n"); + binary_code.push_str("}\n\n"); + generated_code.push_str(&binary_code); + let repo_root = get_repo_root(); let workspace_cargo = repo_root.join("Cargo.toml"); println!("cargo:rerun-if-changed={}", workspace_cargo.display()); @@ -297,7 +320,17 @@ where serializer.serialize_str(value.as_deref().unwrap_or("")) } -fn generate_template_entry(code: &mut String, template_path: &Path, source: &str, manifest_dir: &Path) { +fn is_binary_file(path: &str) -> bool { + path.ends_with(".jar") +} + +fn generate_template_entry( + code: &mut String, + binary_code: &mut String, + template_path: &Path, + source: &str, + manifest_dir: &Path, +) { let (git_files, resolved_base) = get_git_tracked_files(template_path, manifest_dir); if git_files.is_empty() { @@ -334,6 +367,9 @@ fn generate_template_entry(code: &mut String, template_path: &Path, source: &str code.push_str(" {\n"); code.push_str(" let mut files = HashMap::new();\n"); + binary_code.push_str(" {\n"); + binary_code.push_str(" let mut files = HashMap::new();\n"); + for file_path in git_files { // Example file_path: modules/chat-console-rs/src/lib.rs (relative to repo root) // Example resolved_base: modules/chat-console-rs @@ -386,15 +422,25 @@ fn generate_template_entry(code: &mut String, template_path: &Path, source: &str // Example include_path (inside crate): "templates/basic-rs/server/src/lib.rs" // Example include_path (outside crate): ".templates/parent_parent_modules_chat-console-rs/src/lib.rs" // Example relative_str: "src/lib.rs" - code.push_str(&format!( - " files.insert(\"{}\", include_str!(concat!(env!(\"CARGO_MANIFEST_DIR\"), \"/{}\")));\n", - relative_str, include_path - )); + if is_binary_file(&relative_str) { + binary_code.push_str(&format!( + " files.insert(\"{}\", include_bytes!(concat!(env!(\"CARGO_MANIFEST_DIR\"), 
\"/{}\")).as_slice());\n", + relative_str, include_path + )); + } else { + code.push_str(&format!( + " files.insert(\"{}\", include_str!(concat!(env!(\"CARGO_MANIFEST_DIR\"), \"/{}\")));\n", + relative_str, include_path + )); + } } } code.push_str(&format!(" templates.insert(\"{}\", files);\n", source)); code.push_str(" }\n\n"); + + binary_code.push_str(&format!(" templates.insert(\"{}\", files);\n", source)); + binary_code.push_str(" }\n\n"); } /// Get a list of files tracked by git from a given directory diff --git a/crates/cli/src/subcommands/generate.rs b/crates/cli/src/subcommands/generate.rs index 6e5378fede5..ac0ad82b2e6 100644 --- a/crates/cli/src/subcommands/generate.rs +++ b/crates/cli/src/subcommands/generate.rs @@ -6,8 +6,8 @@ use clap::Arg; use clap::ArgAction::{Set, SetTrue}; use fs_err as fs; use spacetimedb_codegen::{ - generate, private_table_names, CodegenOptions, CodegenVisibility, Csharp, Lang, OutputFile, Rust, TypeScript, - UnrealCpp, AUTO_GENERATED_PREFIX, + generate, private_table_names, CodegenOptions, CodegenVisibility, Csharp, Kotlin, Lang, OutputFile, Rust, + TypeScript, UnrealCpp, AUTO_GENERATED_PREFIX, }; use spacetimedb_lib::de::serde::DeserializeWrapper; use spacetimedb_lib::{sats, RawModuleDef}; @@ -20,6 +20,7 @@ use crate::spacetime_config::{ find_and_load_with_env, CommandConfig, CommandSchema, CommandSchemaBuilder, Key, LoadedConfig, SpacetimeConfig, }; use crate::tasks::csharp::dotnet_format; +use crate::tasks::kotlin::ktfmt; use crate::tasks::rust::rustfmt; use crate::util::{resolve_sibling_binary, y_or_n}; use crate::Config; @@ -396,6 +397,9 @@ fn detect_default_language(client_project_dir: &Path) -> anyhow::Result &'static str { match lang { Language::Rust => "rust", Language::Csharp => "csharp", + Language::Kotlin => "kotlin", Language::TypeScript => "typescript", Language::UnrealCpp => "unrealcpp", } @@ -424,6 +429,7 @@ pub fn default_out_dir_for_language(lang: Language) -> Option { match lang { Language::Rust | 
Language::TypeScript => Some(PathBuf::from("src/module_bindings")), Language::Csharp => Some(PathBuf::from("module_bindings")), + Language::Kotlin => Some(PathBuf::from("module_bindings")), Language::UnrealCpp => None, } } @@ -516,6 +522,7 @@ pub async fn run_prepared_generate_configs( }; &csharp_lang as &dyn Lang } + Language::Kotlin => &Kotlin, Language::UnrealCpp => { unreal_cpp_lang = UnrealCpp { module_name: run.module_name.as_ref().unwrap(), @@ -684,6 +691,7 @@ pub async fn exec_from_entries( #[serde(rename_all = "lowercase")] pub enum Language { Csharp, + Kotlin, TypeScript, Rust, #[serde(alias = "uecpp", alias = "ue5cpp", alias = "unreal")] @@ -692,11 +700,18 @@ pub enum Language { impl clap::ValueEnum for Language { fn value_variants<'a>() -> &'a [Self] { - &[Self::Csharp, Self::TypeScript, Self::Rust, Self::UnrealCpp] + &[ + Self::Csharp, + Self::Kotlin, + Self::TypeScript, + Self::Rust, + Self::UnrealCpp, + ] } fn to_possible_value(&self) -> Option { Some(match self { Self::Csharp => clap::builder::PossibleValue::new("csharp").aliases(["c#", "cs"]), + Self::Kotlin => clap::builder::PossibleValue::new("kotlin").aliases(["kt", "KT"]), Self::TypeScript => clap::builder::PossibleValue::new("typescript").aliases(["ts", "TS"]), Self::Rust => clap::builder::PossibleValue::new("rust").aliases(["rs", "RS"]), Self::UnrealCpp => PossibleValue::new("unrealcpp").aliases(["uecpp", "ue5cpp", "unreal"]), @@ -710,6 +725,7 @@ impl Language { match self { Language::Rust => "Rust", Language::Csharp => "C#", + Language::Kotlin => "Kotlin", Language::TypeScript => "TypeScript", Language::UnrealCpp => "Unreal C++", } @@ -719,6 +735,7 @@ impl Language { match self { Language::Rust => rustfmt(generated_files)?, Language::Csharp => dotnet_format(project_dir, generated_files)?, + Language::Kotlin => ktfmt(generated_files)?, Language::TypeScript => { // TODO: implement formatting. 
} diff --git a/crates/cli/src/subcommands/init.rs b/crates/cli/src/subcommands/init.rs index b4701b41164..d6cfa37c585 100644 --- a/crates/cli/src/subcommands/init.rs +++ b/crates/cli/src/subcommands/init.rs @@ -86,6 +86,7 @@ pub enum ClientLanguage { Rust, Csharp, TypeScript, + Kotlin, } impl ClientLanguage { @@ -94,6 +95,7 @@ impl ClientLanguage { ClientLanguage::Rust => "rust", ClientLanguage::Csharp => "csharp", ClientLanguage::TypeScript => "typescript", + ClientLanguage::Kotlin => "kotlin", } } @@ -102,6 +104,7 @@ impl ClientLanguage { "rust" => Ok(Some(ClientLanguage::Rust)), "csharp" | "c#" => Ok(Some(ClientLanguage::Csharp)), "typescript" => Ok(Some(ClientLanguage::TypeScript)), + "kotlin" | "kt" => Ok(Some(ClientLanguage::Kotlin)), _ => Err(anyhow!("Unknown client language: {}", s)), } } @@ -1119,6 +1122,32 @@ pub fn update_csproj_client_to_nuget(dir: &Path) -> anyhow::Result<()> { Ok(()) } +/// Sets up a Kotlin client project: updates the project name and makes gradlew executable. 
+fn setup_kotlin_client(dir: &Path, project_name: &str) -> anyhow::Result<()> { + let settings_path = dir.join("settings.gradle.kts"); + if settings_path.exists() { + let original = fs::read_to_string(&settings_path)?; + let re = regex::Regex::new(r#"rootProject\.name\s*=\s*"[^"]*""#).unwrap(); + let updated = re + .replace(&original, &format!("rootProject.name = \"{}\"", project_name)) + .to_string(); + if updated != original { + fs::write(&settings_path, updated)?; + } + } + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let gradlew = dir.join("gradlew"); + if gradlew.exists() { + fs::set_permissions(&gradlew, fs::Permissions::from_mode(0o755))?; + } + } + + Ok(()) +} + // Helpers fn write_if_changed(path: PathBuf, original: String, root: Element) -> anyhow::Result<()> { @@ -1324,6 +1353,7 @@ fn init_builtin(config: &TemplateConfig, project_path: &Path, is_server_only: bo .ok_or_else(|| anyhow::anyhow!("Template definition missing"))?; let template_files = embedded::get_template_files(); + let template_binary_files = embedded::get_template_binary_files(); if !is_server_only { println!( @@ -1332,7 +1362,7 @@ fn init_builtin(config: &TemplateConfig, project_path: &Path, is_server_only: bo ); let client_source = &template_def.client_source; if let Some(files) = template_files.get(client_source.as_str()) { - copy_embedded_files(files, project_path)?; + copy_embedded_files(files, template_binary_files.get(client_source.as_str()), project_path)?; } else { anyhow::bail!("Client template not found: {}", client_source); } @@ -1353,6 +1383,9 @@ fn init_builtin(config: &TemplateConfig, project_path: &Path, is_server_only: bo Some(ClientLanguage::Csharp) => { update_csproj_client_to_nuget(project_path)?; } + Some(ClientLanguage::Kotlin) => { + setup_kotlin_client(project_path, &config.project_name)?; + } None => {} } } @@ -1364,7 +1397,7 @@ fn init_builtin(config: &TemplateConfig, project_path: &Path, is_server_only: bo let server_dir = 
project_path.join("spacetimedb"); let server_source = &template_def.server_source; if let Some(files) = template_files.get(server_source.as_str()) { - copy_embedded_files(files, &server_dir)?; + copy_embedded_files(files, template_binary_files.get(server_source.as_str()), &server_dir)?; } else { anyhow::bail!("Server template not found: {}", server_source); } @@ -1389,7 +1422,11 @@ fn init_builtin(config: &TemplateConfig, project_path: &Path, is_server_only: bo Ok(()) } -fn copy_embedded_files(files: &HashMap<&str, &str>, target_dir: &Path) -> anyhow::Result<()> { +fn copy_embedded_files( + files: &HashMap<&str, &str>, + binary_files: Option<&HashMap<&str, &[u8]>>, + target_dir: &Path, +) -> anyhow::Result<()> { for (file_path, content) in files { // Skip .template.json files - they're only for template metadata if file_path.ends_with(".template.json") { @@ -1402,6 +1439,15 @@ fn copy_embedded_files(files: &HashMap<&str, &str>, target_dir: &Path) -> anyhow } fs::write(&full_path, content)?; } + if let Some(binaries) = binary_files { + for (file_path, content) in binaries { + let full_path = target_dir.join(file_path); + if let Some(parent) = full_path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(&full_path, content)?; + } + } Ok(()) } @@ -1512,6 +1558,14 @@ fn print_next_steps(config: &TemplateConfig, _project_path: &Path) -> anyhow::Re ); println!(" spacetime generate --lang csharp --out-dir module_bindings --module-path spacetimedb"); } + (TemplateType::Builtin, Some(ServerLanguage::Rust), Some(ClientLanguage::Kotlin)) => { + println!( + " spacetime publish --module-path spacetimedb {}{}", + if config.use_local { "--server local " } else { "" }, + config.project_name + ); + println!(" ./gradlew run"); + } (TemplateType::Empty, _, Some(ClientLanguage::TypeScript)) => { println!(" npm install"); if config.server_lang.is_some() { diff --git a/crates/cli/src/tasks/kotlin.rs b/crates/cli/src/tasks/kotlin.rs new file mode 100644 index 
00000000000..31491943991 --- /dev/null +++ b/crates/cli/src/tasks/kotlin.rs @@ -0,0 +1,31 @@ +use std::ffi::OsString; +use std::path::PathBuf; + +use anyhow::Context; +use itertools::Itertools; + +fn has_ktfmt() -> bool { + duct::cmd!("ktfmt", "--version") + .stdout_null() + .stderr_null() + .run() + .is_ok() +} + +pub(crate) fn ktfmt(files: impl IntoIterator) -> anyhow::Result<()> { + if !has_ktfmt() { + eprintln!("ktfmt not found — skipping Kotlin formatting."); + eprintln!("Install ktfmt from https://github.com/facebook/ktfmt to auto-format generated code."); + return Ok(()); + } + duct::cmd( + "ktfmt", + itertools::chain( + ["--kotlinlang-style"].into_iter().map_into::(), + files.into_iter().map_into(), + ), + ) + .run() + .context("ktfmt failed")?; + Ok(()) +} diff --git a/crates/cli/src/tasks/mod.rs b/crates/cli/src/tasks/mod.rs index 16414efbe97..9d1e30df023 100644 --- a/crates/cli/src/tasks/mod.rs +++ b/crates/cli/src/tasks/mod.rs @@ -60,4 +60,5 @@ pub fn build( pub mod cpp; pub mod csharp; pub mod javascript; +pub mod kotlin; pub mod rust; diff --git a/crates/codegen/src/kotlin.rs b/crates/codegen/src/kotlin.rs new file mode 100644 index 00000000000..6e6f1459b19 --- /dev/null +++ b/crates/codegen/src/kotlin.rs @@ -0,0 +1,2080 @@ +use crate::util::{ + collect_case, is_reducer_invokable, iter_indexes, iter_procedures, iter_reducers, iter_table_names_and_types, + iter_types, print_auto_generated_file_comment, print_auto_generated_version_comment, type_ref_name, +}; +use crate::{CodegenOptions, OutputFile}; + +use super::code_indenter::{CodeIndenter, Indenter}; +use super::Lang; + +use std::ops::Deref; + +use convert_case::{Case, Casing}; +use spacetimedb_lib::sats::layout::PrimitiveType; +use spacetimedb_lib::version::spacetimedb_lib_version; +use spacetimedb_primitives::ColId; +use spacetimedb_schema::def::{IndexAlgorithm, ModuleDef, ReducerDef, TableDef, TypeDef}; +use spacetimedb_schema::identifier::Identifier; +use spacetimedb_schema::schema::TableSchema; 
+use spacetimedb_schema::type_for_generate::{AlgebraicTypeDef, AlgebraicTypeUse}; + +use std::collections::BTreeSet; + +const INDENT: &str = " "; +const SDK_PKG: &str = "com.clockworklabs.spacetimedb_kotlin_sdk.shared_client"; + +/// Kotlin hard keywords that must be escaped with backticks when used as identifiers. +/// See: https://kotlinlang.org/docs/keyword-reference.html#hard-keywords +const KOTLIN_HARD_KEYWORDS: &[&str] = &[ + "as", + "break", + "class", + "continue", + "do", + "else", + "false", + "for", + "fun", + "if", + "in", + "interface", + "is", + "null", + "object", + "package", + "return", + "super", + "this", + "throw", + "true", + "try", + "typealias", + "typeof", + "val", + "var", + "when", + "while", +]; + +/// Escapes a Kotlin identifier with backticks if it collides with a hard keyword. +fn kotlin_ident(name: String) -> String { + if KOTLIN_HARD_KEYWORDS.contains(&name.as_str()) { + format!("`{name}`") + } else { + name + } +} + +pub struct Kotlin; + +impl Lang for Kotlin { + fn generate_type_files(&self, _module: &ModuleDef, _typ: &TypeDef) -> Vec { + // All types are emitted in a single Types.kt file via generate_global_files. 
+ vec![] + } + + fn generate_table_file_from_schema(&self, module: &ModuleDef, table: &TableDef, schema: TableSchema) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + let type_ref = table.product_type_ref; + let product_def = module.typespace_for_generate()[type_ref].as_product().unwrap(); + let type_name = type_ref_name(module, type_ref); + let table_name_pascal = table.accessor_name.deref().to_case(Case::Pascal); + + let is_event = table.is_event; + + // Check if this table has user-defined indexes (event tables never have indexes) + let has_unique_index = !is_event + && iter_indexes(table).any(|idx| idx.accessor_name.is_some() && schema.is_unique(&idx.algorithm.columns())); + let has_btree_index = !is_event + && iter_indexes(table) + .any(|idx| idx.accessor_name.is_some() && !schema.is_unique(&idx.algorithm.columns())); + + // Collect indexed column positions for IxCols generation + let mut ix_col_positions: BTreeSet = BTreeSet::new(); + if !is_event { + for idx in iter_indexes(table) { + if let IndexAlgorithm::BTree(btree) = &idx.algorithm { + for col_pos in btree.columns.iter() { + ix_col_positions.insert(col_pos.idx()); + } + } + } + } + let has_ix_cols = !ix_col_positions.is_empty(); + + // Imports + if has_btree_index { + writeln!(out, "import {SDK_PKG}.BTreeIndex"); + } + writeln!(out, "import {SDK_PKG}.Col"); + writeln!(out, "import {SDK_PKG}.DbConnection"); + writeln!(out, "import {SDK_PKG}.EventContext"); + writeln!(out, "import {SDK_PKG}.InternalSpacetimeApi"); + if has_ix_cols { + writeln!(out, "import {SDK_PKG}.IxCol"); + } + if is_event { + writeln!(out, "import {SDK_PKG}.RemoteEventTable"); + } else if table.primary_key.is_some() { + writeln!(out, "import {SDK_PKG}.RemotePersistentTableWithPrimaryKey"); + } else { + writeln!(out, "import {SDK_PKG}.RemotePersistentTable"); + } + writeln!(out, "import {SDK_PKG}.TableCache"); + if has_unique_index { + 
writeln!(out, "import {SDK_PKG}.UniqueIndex"); + } + writeln!(out, "import {SDK_PKG}.protocol.QueryResult"); + gen_and_print_imports(module, out, product_def.element_types()); + + writeln!(out); + + // Table handle class + let table_marker = if is_event { + "RemoteEventTable" + } else if table.primary_key.is_some() { + "RemotePersistentTableWithPrimaryKey" + } else { + "RemotePersistentTable" + }; + writeln!(out, "/** Client-side handle for the `{}` table. */", table.name.deref()); + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "class {table_name_pascal}TableHandle internal constructor("); + out.indent(1); + writeln!(out, "private val conn: DbConnection,"); + writeln!(out, "private val tableCache: TableCache<{type_name}, *>,"); + out.dedent(1); + writeln!(out, ") : {table_marker}<{type_name}> {{"); + out.indent(1); + + // Constants + writeln!(out, "companion object {{"); + out.indent(1); + writeln!(out, "const val TABLE_NAME = \"{}\"", table.name.deref()); + writeln!(out); + // Field name constants + for (ident, _) in product_def.elements.iter() { + let const_name = ident.deref().to_case(Case::ScreamingSnake); + writeln!(out, "const val FIELD_{const_name} = \"{}\"", ident.deref()); + } + writeln!(out); + writeln!(out, "fun createTableCache(): TableCache<{type_name}, *> {{"); + out.indent(1); + // Primary key extractor + if let Some(pk_col) = table.primary_key { + let pk_field = table.get_column(pk_col).unwrap(); + let pk_field_camel = kotlin_ident(pk_field.accessor_name.deref().to_case(Case::Camel)); + writeln!( + out, + "return TableCache.withPrimaryKey({{ reader -> {type_name}.decode(reader) }}) {{ row -> row.{pk_field_camel} }}" + ); + } else { + writeln!( + out, + "return TableCache.withContentKey {{ reader -> {type_name}.decode(reader) }}" + ); + } + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // Accessors (event tables don't store rows) + if !is_event { + writeln!(out, "override 
fun count(): Int = tableCache.count()"); + writeln!(out, "override fun all(): List<{type_name}> = tableCache.all()"); + writeln!(out, "override fun iter(): Sequence<{type_name}> = tableCache.iter()"); + writeln!(out); + } + + // Callbacks + writeln!( + out, + "override fun onInsert(cb: (EventContext, {type_name}) -> Unit) {{ tableCache.onInsert(cb) }}" + ); + writeln!( + out, + "override fun removeOnInsert(cb: (EventContext, {type_name}) -> Unit) {{ tableCache.removeOnInsert(cb) }}" + ); + if !is_event { + writeln!( + out, + "override fun onDelete(cb: (EventContext, {type_name}) -> Unit) {{ tableCache.onDelete(cb) }}" + ); + if table.primary_key.is_some() { + writeln!(out, "override fun onUpdate(cb: (EventContext, {type_name}, {type_name}) -> Unit) {{ tableCache.onUpdate(cb) }}"); + } + writeln!(out, "override fun onBeforeDelete(cb: (EventContext, {type_name}) -> Unit) {{ tableCache.onBeforeDelete(cb) }}"); + writeln!(out); + writeln!(out, "override fun removeOnDelete(cb: (EventContext, {type_name}) -> Unit) {{ tableCache.removeOnDelete(cb) }}"); + if table.primary_key.is_some() { + writeln!(out, "override fun removeOnUpdate(cb: (EventContext, {type_name}, {type_name}) -> Unit) {{ tableCache.removeOnUpdate(cb) }}"); + } + writeln!(out, "override fun removeOnBeforeDelete(cb: (EventContext, {type_name}) -> Unit) {{ tableCache.removeOnBeforeDelete(cb) }}"); + } + writeln!(out); + + // Index properties + let get_field_name_and_type = |col_pos: ColId| -> (String, String) { + let (field_name, field_type) = &product_def.elements[col_pos.idx()]; + let name_camel = kotlin_ident(field_name.deref().to_case(Case::Camel)); + let kt_type = kotlin_type(module, field_type); + (name_camel, kt_type) + }; + + for idx in iter_indexes(table) { + let Some(accessor_name) = idx.accessor_name.as_ref() else { + // System-generated indexes don't get client-side accessors + continue; + }; + + let columns = idx.algorithm.columns(); + let is_unique = schema.is_unique(&columns); + let 
index_name_camel = kotlin_ident(accessor_name.deref().to_case(Case::Camel)); + let index_class = if is_unique { "UniqueIndex" } else { "BTreeIndex" }; + + match columns.as_singleton() { + Some(col_pos) => { + // Single-column index + let (field_camel, kt_ty) = get_field_name_and_type(col_pos); + writeln!( + out, + "val {index_name_camel} = {index_class}<{type_name}, {kt_ty}>(tableCache) {{ it.{field_camel} }}" + ); + } + None => { + // Multi-column index + let col_fields: Vec<(String, String)> = columns.iter().map(get_field_name_and_type).collect(); + + match col_fields.len() { + 2 => { + let col_types = format!("{}, {}", col_fields[0].1, col_fields[1].1); + let key_expr = format!("Pair(it.{}, it.{})", col_fields[0].0, col_fields[1].0); + writeln!( + out, + "val {index_name_camel} = {index_class}<{type_name}, Pair<{col_types}>>(tableCache) {{ {key_expr} }}" + ); + } + 3 => { + let col_types = format!("{}, {}, {}", col_fields[0].1, col_fields[1].1, col_fields[2].1); + let key_expr = format!( + "Triple(it.{}, it.{}, it.{})", + col_fields[0].0, col_fields[1].0, col_fields[2].0 + ); + writeln!( + out, + "val {index_name_camel} = {index_class}<{type_name}, Triple<{col_types}>>(tableCache) {{ {key_expr} }}" + ); + } + _ => { + let key_expr_fields = col_fields + .iter() + .map(|(name, _)| format!("it.{name}")) + .collect::>() + .join(", "); + writeln!( + out, + "val {index_name_camel} = {index_class}<{type_name}, List>(tableCache) {{ listOf({key_expr_fields}) }}" + ); + } + } + } + } + writeln!(out); + } + + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // --- {Table}Cols class: typed column references for all fields --- + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "class {table_name_pascal}Cols(tableName: String) {{"); + out.indent(1); + for (ident, field_type) in product_def.elements.iter() { + let field_camel = kotlin_ident(ident.deref().to_case(Case::Camel)); + let col_name = ident.deref(); + let value_type = match field_type { 
+ AlgebraicTypeUse::Option(inner) => kotlin_type(module, inner), + _ => kotlin_type(module, field_type), + }; + writeln!( + out, + "val {field_camel} = Col<{type_name}, {value_type}>(tableName, \"{col_name}\")" + ); + } + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // --- {Table}IxCols class: typed column references for indexed fields only --- + if has_ix_cols { + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "class {table_name_pascal}IxCols(tableName: String) {{"); + out.indent(1); + for (i, (ident, field_type)) in product_def.elements.iter().enumerate() { + if !ix_col_positions.contains(&i) { + continue; + } + let field_camel = kotlin_ident(ident.deref().to_case(Case::Camel)); + let col_name = ident.deref(); + let value_type = match field_type { + AlgebraicTypeUse::Option(inner) => kotlin_type(module, inner), + _ => kotlin_type(module, field_type), + }; + writeln!( + out, + "val {field_camel} = IxCol<{type_name}, {value_type}>(tableName, \"{col_name}\")" + ); + } + out.dedent(1); + writeln!(out, "}}"); + } else { + // No indexed columns — emit a simple empty class + writeln!(out, "class {table_name_pascal}IxCols"); + } + + OutputFile { + filename: format!("{table_name_pascal}TableHandle.kt"), + code: output.into_inner(), + } + } + + fn generate_reducer_file(&self, module: &ModuleDef, reducer: &ReducerDef) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + let reducer_name_pascal = reducer.accessor_name.deref().to_case(Case::Pascal); + + // Imports + writeln!(out, "import {SDK_PKG}.bsatn.BsatnReader"); + writeln!(out, "import {SDK_PKG}.bsatn.BsatnWriter"); + gen_and_print_imports(module, out, reducer.params_for_generate.element_types()); + + writeln!(out); + + // Emit args data class with encode/decode (if there are params) + if !reducer.params_for_generate.elements.is_empty() { + writeln!(out, "/** Arguments for the `{}` 
reducer. */", reducer.name.deref()); + writeln!(out, "data class {reducer_name_pascal}Args("); + out.indent(1); + for (i, (ident, ty)) in reducer.params_for_generate.elements.iter().enumerate() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + let kotlin_ty = kotlin_type(module, ty); + let comma = if i + 1 < reducer.params_for_generate.elements.len() { + "," + } else { + "" + }; + writeln!(out, "val {field_name}: {kotlin_ty}{comma}"); + } + out.dedent(1); + writeln!(out, ") {{"); + out.indent(1); + + // encode method + writeln!(out, "/** Encodes these arguments to BSATN. */"); + writeln!(out, "fun encode(): ByteArray {{"); + out.indent(1); + writeln!(out, "val writer = BsatnWriter()"); + for (ident, ty) in reducer.params_for_generate.elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + write_encode_field(module, out, &field_name, ty); + } + writeln!(out, "return writer.toByteArray()"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // companion object with decode + writeln!(out, "companion object {{"); + out.indent(1); + writeln!(out, "/** Decodes [{reducer_name_pascal}Args] from BSATN. */"); + writeln!(out, "fun decode(reader: BsatnReader): {reducer_name_pascal}Args {{"); + out.indent(1); + for (ident, ty) in reducer.params_for_generate.elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + write_decode_field(module, out, &field_name, ty); + } + let field_names: Vec = reducer + .params_for_generate + .elements + .iter() + .map(|(ident, _)| kotlin_ident(ident.deref().to_case(Case::Camel))) + .collect(); + let args = field_names.join(", "); + writeln!(out, "return {reducer_name_pascal}Args({args})"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + } + + // Reducer companion object + writeln!(out, "/** Constants for the `{}` reducer. 
*/", reducer.name.deref()); + writeln!(out, "object {reducer_name_pascal}Reducer {{"); + out.indent(1); + writeln!(out, "const val REDUCER_NAME = \"{}\"", reducer.name.deref()); + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: format!("{reducer_name_pascal}Reducer.kt"), + code: output.into_inner(), + } + } + + fn generate_procedure_file( + &self, + module: &ModuleDef, + procedure: &spacetimedb_schema::def::ProcedureDef, + ) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + // Imports + writeln!(out, "import {SDK_PKG}.bsatn.BsatnReader"); + writeln!(out, "import {SDK_PKG}.bsatn.BsatnWriter"); + gen_and_print_imports( + module, + out, + procedure + .params_for_generate + .element_types() + .chain([&procedure.return_type_for_generate]), + ); + + let procedure_name_pascal = procedure.accessor_name.deref().to_case(Case::Pascal); + + if procedure.params_for_generate.elements.is_empty() { + writeln!(out, "object {procedure_name_pascal}Procedure {{"); + out.indent(1); + writeln!(out, "const val PROCEDURE_NAME = \"{}\"", procedure.name.deref()); + let return_ty = kotlin_type(module, &procedure.return_type_for_generate); + writeln!(out, "// Returns: {return_ty}"); + out.dedent(1); + writeln!(out, "}}"); + } else { + writeln!(out, "/** Arguments for the `{}` procedure. */", procedure.name.deref()); + writeln!(out, "data class {procedure_name_pascal}Args("); + out.indent(1); + for (i, (ident, ty)) in procedure.params_for_generate.elements.iter().enumerate() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + let kotlin_ty = kotlin_type(module, ty); + let comma = if i + 1 < procedure.params_for_generate.elements.len() { + "," + } else { + "" + }; + writeln!(out, "val {field_name}: {kotlin_ty}{comma}"); + } + out.dedent(1); + writeln!(out, ") {{"); + out.indent(1); + + // encode method + writeln!(out, "/** Encodes these arguments to BSATN. 
*/"); + writeln!(out, "fun encode(): ByteArray {{"); + out.indent(1); + writeln!(out, "val writer = BsatnWriter()"); + for (ident, ty) in procedure.params_for_generate.elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + write_encode_field(module, out, &field_name, ty); + } + writeln!(out, "return writer.toByteArray()"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // companion object with decode + writeln!(out, "companion object {{"); + out.indent(1); + writeln!(out, "/** Decodes [{procedure_name_pascal}Args] from BSATN. */"); + writeln!(out, "fun decode(reader: BsatnReader): {procedure_name_pascal}Args {{"); + out.indent(1); + for (ident, ty) in procedure.params_for_generate.elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + write_decode_field(module, out, &field_name, ty); + } + let field_names: Vec = procedure + .params_for_generate + .elements + .iter() + .map(|(ident, _)| kotlin_ident(ident.deref().to_case(Case::Camel))) + .collect(); + let args = field_names.join(", "); + writeln!(out, "return {procedure_name_pascal}Args({args})"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + writeln!(out, "object {procedure_name_pascal}Procedure {{"); + out.indent(1); + writeln!(out, "const val PROCEDURE_NAME = \"{}\"", procedure.name.deref()); + let return_ty = kotlin_type(module, &procedure.return_type_for_generate); + writeln!(out, "// Returns: {return_ty}"); + out.dedent(1); + writeln!(out, "}}"); + } + + OutputFile { + filename: format!("{procedure_name_pascal}Procedure.kt"), + code: output.into_inner(), + } + } + + fn generate_global_files(&self, module: &ModuleDef, options: &CodegenOptions) -> Vec { + let files = vec![ + generate_types_file(module), + generate_remote_tables_file(module, options), + generate_remote_reducers_file(module, options), + generate_remote_procedures_file(module, 
options), + generate_module_file(module, options), + ]; + + files + } +} + +// --- Type mapping --- + +fn kotlin_type(module: &ModuleDef, ty: &AlgebraicTypeUse) -> String { + match ty { + AlgebraicTypeUse::Unit => "Unit".to_string(), + AlgebraicTypeUse::Never => "Nothing".to_string(), + AlgebraicTypeUse::Identity => "Identity".to_string(), + AlgebraicTypeUse::ConnectionId => "ConnectionId".to_string(), + AlgebraicTypeUse::Timestamp => "Timestamp".to_string(), + AlgebraicTypeUse::TimeDuration => "TimeDuration".to_string(), + AlgebraicTypeUse::ScheduleAt => "ScheduleAt".to_string(), + AlgebraicTypeUse::Uuid => "SpacetimeUuid".to_string(), + AlgebraicTypeUse::Option(inner_ty) => format!("{}?", kotlin_type(module, inner_ty)), + AlgebraicTypeUse::Result { ok_ty, err_ty } => format!( + "SpacetimeResult<{}, {}>", + kotlin_type(module, ok_ty), + kotlin_type(module, err_ty) + ), + AlgebraicTypeUse::Primitive(prim) => match prim { + PrimitiveType::Bool => "Boolean", + PrimitiveType::I8 => "Byte", + PrimitiveType::U8 => "UByte", + PrimitiveType::I16 => "Short", + PrimitiveType::U16 => "UShort", + PrimitiveType::I32 => "Int", + PrimitiveType::U32 => "UInt", + PrimitiveType::I64 => "Long", + PrimitiveType::U64 => "ULong", + PrimitiveType::I128 => "Int128", + PrimitiveType::U128 => "UInt128", + PrimitiveType::I256 => "Int256", + PrimitiveType::U256 => "UInt256", + PrimitiveType::F32 => "Float", + PrimitiveType::F64 => "Double", + } + .to_string(), + AlgebraicTypeUse::String => "String".to_string(), + AlgebraicTypeUse::Array(elem_ty) => { + if matches!(&**elem_ty, AlgebraicTypeUse::Primitive(PrimitiveType::U8)) { + return "ByteArray".to_string(); + } + format!("List<{}>", kotlin_type(module, elem_ty)) + } + AlgebraicTypeUse::Ref(r) => type_ref_name(module, *r), + } +} + +/// Returns the FQN import path for a type. Used for import statements. 
+fn kotlin_type_fqn(_module: &ModuleDef, ty: &AlgebraicTypeUse) -> Option { + match ty { + AlgebraicTypeUse::Identity => Some(format!("{SDK_PKG}.type.Identity")), + AlgebraicTypeUse::ConnectionId => Some(format!("{SDK_PKG}.type.ConnectionId")), + AlgebraicTypeUse::Timestamp => Some(format!("{SDK_PKG}.type.Timestamp")), + AlgebraicTypeUse::TimeDuration => Some(format!("{SDK_PKG}.type.TimeDuration")), + AlgebraicTypeUse::ScheduleAt => Some(format!("{SDK_PKG}.type.ScheduleAt")), + AlgebraicTypeUse::Uuid => Some(format!("{SDK_PKG}.type.SpacetimeUuid")), + AlgebraicTypeUse::Result { .. } => Some(format!("{SDK_PKG}.SpacetimeResult")), + AlgebraicTypeUse::Primitive(prim) => match prim { + PrimitiveType::I128 => Some(format!("{SDK_PKG}.Int128")), + PrimitiveType::U128 => Some(format!("{SDK_PKG}.UInt128")), + PrimitiveType::I256 => Some(format!("{SDK_PKG}.Int256")), + PrimitiveType::U256 => Some(format!("{SDK_PKG}.UInt256")), + _ => None, + }, + _ => None, + } +} + +// --- BSATN encode/decode generation helpers --- + +/// Write the BSATN encode call for a single field. 
+fn write_encode_field(module: &ModuleDef, out: &mut Indenter, field_name: &str, ty: &AlgebraicTypeUse) { + match ty { + AlgebraicTypeUse::Primitive(prim) => { + let method = match prim { + PrimitiveType::Bool => "writeBool", + PrimitiveType::I8 => "writeI8", + PrimitiveType::U8 => "writeU8", + PrimitiveType::I16 => "writeI16", + PrimitiveType::U16 => "writeU16", + PrimitiveType::I32 => "writeI32", + PrimitiveType::U32 => "writeU32", + PrimitiveType::I64 => "writeI64", + PrimitiveType::U64 => "writeU64", + PrimitiveType::F32 => "writeF32", + PrimitiveType::F64 => "writeF64", + PrimitiveType::I128 | PrimitiveType::U128 | PrimitiveType::I256 | PrimitiveType::U256 => { + // These SDK wrapper types have their own encode method + writeln!(out, "{field_name}.encode(writer)"); + return; + } + }; + writeln!(out, "writer.{method}({field_name})"); + } + AlgebraicTypeUse::String => { + writeln!(out, "writer.writeString({field_name})"); + } + AlgebraicTypeUse::Identity + | AlgebraicTypeUse::ConnectionId + | AlgebraicTypeUse::Timestamp + | AlgebraicTypeUse::TimeDuration + | AlgebraicTypeUse::Uuid => { + writeln!(out, "{field_name}.encode(writer)"); + } + AlgebraicTypeUse::ScheduleAt => { + writeln!(out, "{field_name}.encode(writer)"); + } + AlgebraicTypeUse::Ref(_) => { + writeln!(out, "{field_name}.encode(writer)"); + } + AlgebraicTypeUse::Option(inner) => { + writeln!(out, "if ({field_name} != null) {{"); + out.indent(1); + writeln!(out, "writer.writeSumTag(0u)"); + write_encode_value(module, out, field_name, inner); + out.dedent(1); + writeln!(out, "}} else {{"); + out.indent(1); + writeln!(out, "writer.writeSumTag(1u)"); + out.dedent(1); + writeln!(out, "}}"); + } + AlgebraicTypeUse::Array(elem) => { + if matches!(&**elem, AlgebraicTypeUse::Primitive(PrimitiveType::U8)) { + writeln!(out, "writer.writeByteArray({field_name})"); + } else { + writeln!(out, "writer.writeArrayLen({field_name}.size)"); + writeln!(out, "for (elem in {field_name}) {{"); + out.indent(1); + 
write_encode_value(module, out, "elem", elem); + out.dedent(1); + writeln!(out, "}}"); + } + } + AlgebraicTypeUse::Result { ok_ty, err_ty } => { + writeln!(out, "when ({field_name}) {{"); + out.indent(1); + writeln!(out, "is SpacetimeResult.Ok -> {{"); + out.indent(1); + writeln!(out, "writer.writeSumTag(0u)"); + write_encode_value(module, out, &format!("{field_name}.value"), ok_ty); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out, "is SpacetimeResult.Err -> {{"); + out.indent(1); + writeln!(out, "writer.writeSumTag(1u)"); + write_encode_value(module, out, &format!("{field_name}.error"), err_ty); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + } + AlgebraicTypeUse::Unit => { + // Unit is encoded as empty product — nothing to write + } + AlgebraicTypeUse::Never => { + writeln!(out, "// Never type — unreachable"); + } + } +} + +/// Write encode for a value expression (not a field reference). +fn write_encode_value(module: &ModuleDef, out: &mut Indenter, expr: &str, ty: &AlgebraicTypeUse) { + // For simple primitives, delegate to the same logic + write_encode_field(module, out, expr, ty); +} + +/// Write the BSATN decode expression for a type, returning a string expression. 
+fn write_decode_expr(module: &ModuleDef, ty: &AlgebraicTypeUse) -> String { + match ty { + AlgebraicTypeUse::Primitive(prim) => { + let method = match prim { + PrimitiveType::Bool => "reader.readBool()", + PrimitiveType::I8 => "reader.readI8()", + PrimitiveType::U8 => "reader.readU8()", + PrimitiveType::I16 => "reader.readI16()", + PrimitiveType::U16 => "reader.readU16()", + PrimitiveType::I32 => "reader.readI32()", + PrimitiveType::U32 => "reader.readU32()", + PrimitiveType::I64 => "reader.readI64()", + PrimitiveType::U64 => "reader.readU64()", + PrimitiveType::F32 => "reader.readF32()", + PrimitiveType::F64 => "reader.readF64()", + PrimitiveType::I128 => "Int128.decode(reader)", + PrimitiveType::U128 => "UInt128.decode(reader)", + PrimitiveType::I256 => "Int256.decode(reader)", + PrimitiveType::U256 => "UInt256.decode(reader)", + }; + method.to_string() + } + AlgebraicTypeUse::String => "reader.readString()".to_string(), + AlgebraicTypeUse::Identity => "Identity.decode(reader)".to_string(), + AlgebraicTypeUse::ConnectionId => "ConnectionId.decode(reader)".to_string(), + AlgebraicTypeUse::Timestamp => "Timestamp.decode(reader)".to_string(), + AlgebraicTypeUse::TimeDuration => "TimeDuration.decode(reader)".to_string(), + AlgebraicTypeUse::ScheduleAt => "ScheduleAt.decode(reader)".to_string(), + AlgebraicTypeUse::Uuid => "SpacetimeUuid.decode(reader)".to_string(), + AlgebraicTypeUse::Ref(r) => { + let name = type_ref_name(module, *r); + format!("{name}.decode(reader)") + } + AlgebraicTypeUse::Unit => "Unit".to_string(), + AlgebraicTypeUse::Never => "error(\"Never type\")".to_string(), + // Option, Array, Result are handled inline in write_decode_field + AlgebraicTypeUse::Option(_) | AlgebraicTypeUse::Array(_) | AlgebraicTypeUse::Result { .. } => { + // These need multi-line decode; handled by write_decode_field + String::new() + } + } +} + +/// Returns true if the type can be decoded as a single expression. 
+fn is_simple_decode(ty: &AlgebraicTypeUse) -> bool { + !matches!( + ty, + AlgebraicTypeUse::Option(_) | AlgebraicTypeUse::Array(_) | AlgebraicTypeUse::Result { .. } + ) +} + +/// Write the decode for a field, assigning to a val. +fn write_decode_field(module: &ModuleDef, out: &mut Indenter, var_name: &str, ty: &AlgebraicTypeUse) { + match ty { + AlgebraicTypeUse::Option(inner) => { + if is_simple_decode(inner) { + let inner_expr = write_decode_expr(module, inner); + writeln!( + out, + "val {var_name} = if (reader.readSumTag().toInt() == 0) {inner_expr} else null" + ); + } else { + writeln!(out, "val {var_name} = if (reader.readSumTag().toInt() == 0) {{"); + out.indent(1); + write_decode_field(module, out, "__inner", inner); + writeln!(out, "__inner"); + out.dedent(1); + writeln!(out, "}} else null"); + } + } + AlgebraicTypeUse::Array(elem) => { + if matches!(&**elem, AlgebraicTypeUse::Primitive(PrimitiveType::U8)) { + writeln!(out, "val {var_name} = reader.readByteArray()"); + } else if is_simple_decode(elem) { + let elem_expr = write_decode_expr(module, elem); + writeln!(out, "val {var_name} = List(reader.readArrayLen()) {{ {elem_expr} }}"); + } else { + writeln!(out, "val __{var_name}Len = reader.readArrayLen()"); + writeln!(out, "val {var_name} = buildList(__{var_name}Len) {{"); + out.indent(1); + writeln!(out, "repeat(__{var_name}Len) {{"); + out.indent(1); + write_decode_field(module, out, "__elem", elem); + writeln!(out, "add(__elem)"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + } + } + AlgebraicTypeUse::Result { ok_ty, err_ty } => { + writeln!(out, "val {var_name} = when (reader.readSumTag().toInt()) {{"); + out.indent(1); + if is_simple_decode(ok_ty) { + let ok_expr = write_decode_expr(module, ok_ty); + writeln!(out, "0 -> SpacetimeResult.Ok({ok_expr})"); + } else { + writeln!(out, "0 -> {{"); + out.indent(1); + write_decode_field(module, out, "__ok", ok_ty); + writeln!(out, "SpacetimeResult.Ok(__ok)"); + 
out.dedent(1); + writeln!(out, "}}"); + } + if is_simple_decode(err_ty) { + let err_expr = write_decode_expr(module, err_ty); + writeln!(out, "1 -> SpacetimeResult.Err({err_expr})"); + } else { + writeln!(out, "1 -> {{"); + out.indent(1); + write_decode_field(module, out, "__err", err_ty); + writeln!(out, "SpacetimeResult.Err(__err)"); + out.dedent(1); + writeln!(out, "}}"); + } + writeln!(out, "else -> error(\"Unknown Result tag\")"); + out.dedent(1); + writeln!(out, "}}"); + } + _ => { + let expr = write_decode_expr(module, ty); + writeln!(out, "val {var_name} = {expr}"); + } + } +} + +// --- File generation helpers --- + +fn print_file_header(output: &mut Indenter) { + print_auto_generated_file_comment(output); + writeln!(output, "@file:Suppress(\"UNUSED\", \"SpellCheckingInspection\")"); + writeln!(output); + writeln!(output, "package module_bindings"); +} + +fn gen_and_print_imports<'a>( + module: &ModuleDef, + out: &mut Indenter, + roots: impl Iterator, +) { + let mut imports = BTreeSet::new(); + + for ty in roots { + collect_type_imports(module, ty, &mut imports); + } + + if !imports.is_empty() { + for import in imports { + writeln!(out, "import {import}"); + } + } +} + +fn collect_type_imports(module: &ModuleDef, ty: &AlgebraicTypeUse, imports: &mut BTreeSet) { + if let Some(fqn) = kotlin_type_fqn(module, ty) { + imports.insert(fqn); + } + match ty { + AlgebraicTypeUse::Result { ok_ty, err_ty } => { + collect_type_imports(module, ok_ty, imports); + collect_type_imports(module, err_ty, imports); + } + AlgebraicTypeUse::Option(inner) => { + collect_type_imports(module, inner, imports); + } + AlgebraicTypeUse::Array(inner) => { + collect_type_imports(module, inner, imports); + } + _ => {} + } +} + +// --- Types.kt --- + +fn generate_types_file(module: &ModuleDef) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + // Collect imports from all types + let mut 
imports = BTreeSet::new(); + // Always import BSATN reader/writer for encode/decode + imports.insert(format!("{SDK_PKG}.bsatn.BsatnReader")); + imports.insert(format!("{SDK_PKG}.bsatn.BsatnWriter")); + + for ty in iter_types(module) { + match &module.typespace_for_generate()[ty.ty] { + AlgebraicTypeDef::Product(product) => { + for (_, field_ty) in product.elements.iter() { + collect_type_imports(module, field_ty, &mut imports); + } + } + AlgebraicTypeDef::Sum(sum) => { + for (_, variant_ty) in sum.variants.iter() { + collect_type_imports(module, variant_ty, &mut imports); + } + } + AlgebraicTypeDef::PlainEnum(_) => {} + } + } + if !imports.is_empty() { + for import in &imports { + writeln!(out, "import {import}"); + } + writeln!(out); + } + + let reducer_type_names: BTreeSet = module + .reducers() + .map(|reducer| reducer.accessor_name.deref().to_case(Case::Pascal)) + .collect(); + + for ty in iter_types(module) { + let type_name = collect_case(Case::Pascal, ty.accessor_name.name_segments()); + if reducer_type_names.contains(&type_name) { + continue; + } + + match &module.typespace_for_generate()[ty.ty] { + AlgebraicTypeDef::Product(product) => { + define_product_type(module, out, &type_name, &product.elements); + } + AlgebraicTypeDef::Sum(sum) => { + define_sum_type(module, out, &type_name, &sum.variants); + } + AlgebraicTypeDef::PlainEnum(plain_enum) => { + define_plain_enum(out, &type_name, &plain_enum.variants); + } + } + } + + OutputFile { + filename: "Types.kt".to_string(), + code: output.into_inner(), + } +} + +fn define_product_type( + module: &ModuleDef, + out: &mut Indenter, + name: &str, + elements: &[(Identifier, AlgebraicTypeUse)], +) { + if elements.is_empty() { + writeln!(out, "/** Data type `{name}` from the module schema. */"); + writeln!(out, "data object {name} {{"); + out.indent(1); + writeln!(out, "/** Encodes this value to BSATN. 
*/"); + writeln!(out, "fun encode(writer: BsatnWriter) {{ }}"); + writeln!(out); + writeln!(out, "/** Decodes a [{name}] from BSATN. */"); + writeln!(out, "fun decode(reader: BsatnReader): {name} = {name}"); + out.dedent(1); + writeln!(out, "}}"); + } else { + writeln!(out, "/** Data type `{name}` from the module schema. */"); + writeln!(out, "data class {name}("); + out.indent(1); + for (i, (ident, ty)) in elements.iter().enumerate() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + let kotlin_ty = kotlin_type(module, ty); + let comma = if i + 1 < elements.len() { "," } else { "" }; + writeln!(out, "val {field_name}: {kotlin_ty}{comma}"); + } + out.dedent(1); + writeln!(out, ") {{"); + out.indent(1); + + // encode method + writeln!(out, "/** Encodes this value to BSATN. */"); + writeln!(out, "fun encode(writer: BsatnWriter) {{"); + out.indent(1); + for (ident, ty) in elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + write_encode_field(module, out, &field_name, ty); + } + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // companion object with decode + writeln!(out, "companion object {{"); + out.indent(1); + writeln!(out, "/** Decodes a [{name}] from BSATN. 
*/"); + writeln!(out, "fun decode(reader: BsatnReader): {name} {{"); + out.indent(1); + for (ident, ty) in elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + write_decode_field(module, out, &field_name, ty); + } + // Constructor call + let field_names: Vec = elements + .iter() + .map(|(ident, _)| kotlin_ident(ident.deref().to_case(Case::Camel))) + .collect(); + let args = field_names.join(", "); + writeln!(out, "return {name}({args})"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + + // ByteArray fields need custom equals/hashCode + let has_byte_array = elements.iter().any(|(_, ty)| { + matches!(ty, AlgebraicTypeUse::Array(inner) if matches!(&**inner, AlgebraicTypeUse::Primitive(PrimitiveType::U8))) + }); + if has_byte_array { + writeln!(out); + // equals + writeln!(out, "override fun equals(other: Any?): Boolean {{"); + out.indent(1); + writeln!(out, "if (this === other) return true"); + writeln!(out, "if (other !is {name}) return false"); + for (ident, ty) in elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + if matches!(ty, AlgebraicTypeUse::Array(inner) if matches!(&**inner, AlgebraicTypeUse::Primitive(PrimitiveType::U8))) + { + writeln!(out, "if (!{field_name}.contentEquals(other.{field_name})) return false"); + } else { + writeln!(out, "if ({field_name} != other.{field_name}) return false"); + } + } + writeln!(out, "return true"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + // hashCode + writeln!(out, "override fun hashCode(): Int {{"); + out.indent(1); + writeln!(out, "var result = 0"); + for (ident, ty) in elements.iter() { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + if matches!(ty, AlgebraicTypeUse::Array(inner) if matches!(&**inner, AlgebraicTypeUse::Primitive(PrimitiveType::U8))) + { + writeln!(out, "result = 31 * result + {field_name}.contentHashCode()"); + } else { + writeln!(out, "result = 31 * 
result + {field_name}.hashCode()"); + } + } + writeln!(out, "return result"); + out.dedent(1); + writeln!(out, "}}"); + } + + out.dedent(1); + writeln!(out, "}}"); + } + writeln!(out); +} + +/// Returns the Kotlin type name for `ty`, qualifying with `module_bindings.` when +/// a variant name in `variant_names` would shadow the type inside a sealed interface scope. +fn kotlin_type_avoiding_variants(module: &ModuleDef, ty: &AlgebraicTypeUse, variant_names: &[String]) -> String { + let base = kotlin_type(module, ty); + if variant_names.contains(&base) { + format!("module_bindings.{base}") + } else { + base + } +} + +/// Like [write_decode_expr] but qualifies `Ref` types that collide with variant names. +fn write_decode_expr_avoiding_variants(module: &ModuleDef, ty: &AlgebraicTypeUse, variant_names: &[String]) -> String { + if let AlgebraicTypeUse::Ref(r) = ty { + let name = type_ref_name(module, *r); + if variant_names.contains(&name) { + return format!("module_bindings.{name}.decode(reader)"); + } + } + write_decode_expr(module, ty) +} + +fn define_sum_type(module: &ModuleDef, out: &mut Indenter, name: &str, variants: &[(Identifier, AlgebraicTypeUse)]) { + assert!( + variants.len() <= 256, + "Sum type `{name}` has {} variants, but BSATN sum tags are limited to 256", + variants.len() + ); + // Collect all variant names so we can detect when a payload type name collides + // with a variant name (which would resolve to the sealed interface member instead + // of the top-level type). + let variant_names: Vec = variants + .iter() + .map(|(ident, _)| ident.deref().to_case(Case::Pascal)) + .collect(); + + writeln!(out, "/** Sum type `{name}` from the module schema. 
*/"); + writeln!(out, "sealed interface {name} {{"); + out.indent(1); + + // Variants + for (ident, ty) in variants.iter() { + let variant_name = ident.deref().to_case(Case::Pascal); + match ty { + AlgebraicTypeUse::Unit => { + writeln!(out, "data object {variant_name} : {name}"); + } + _ => { + let kotlin_ty = kotlin_type_avoiding_variants(module, ty, &variant_names); + writeln!(out, "data class {variant_name}(val value: {kotlin_ty}) : {name}"); + } + } + } + writeln!(out); + + // encode method + writeln!(out, "fun encode(writer: BsatnWriter) {{"); + out.indent(1); + writeln!(out, "when (this) {{"); + out.indent(1); + for (i, (ident, ty)) in variants.iter().enumerate() { + let variant_name = ident.deref().to_case(Case::Pascal); + let tag = i; + match ty { + AlgebraicTypeUse::Unit => { + writeln!(out, "is {variant_name} -> writer.writeSumTag({tag}u)"); + } + _ => { + writeln!(out, "is {variant_name} -> {{"); + out.indent(1); + writeln!(out, "writer.writeSumTag({tag}u)"); + write_encode_field(module, out, "value", ty); + out.dedent(1); + writeln!(out, "}}"); + } + } + } + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // companion decode + writeln!(out, "companion object {{"); + out.indent(1); + writeln!(out, "fun decode(reader: BsatnReader): {name} {{"); + out.indent(1); + writeln!(out, "return when (val tag = reader.readSumTag().toInt()) {{"); + out.indent(1); + for (i, (ident, ty)) in variants.iter().enumerate() { + let variant_name = ident.deref().to_case(Case::Pascal); + match ty { + AlgebraicTypeUse::Unit => { + writeln!(out, "{i} -> {variant_name}"); + } + _ => { + if is_simple_decode(ty) { + let expr = write_decode_expr_avoiding_variants(module, ty, &variant_names); + writeln!(out, "{i} -> {variant_name}({expr})"); + } else { + writeln!(out, "{i} -> {{"); + out.indent(1); + write_decode_field(module, out, "__value", ty); + writeln!(out, "{variant_name}(__value)"); + out.dedent(1); + writeln!(out, "}}"); + } 
+ } + } + } + writeln!(out, "else -> error(\"Unknown {name} tag: $tag\")"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); +} + +fn define_plain_enum(out: &mut Indenter, name: &str, variants: &[Identifier]) { + assert!( + variants.len() <= 256, + "Enum `{name}` has {} variants, but BSATN sum tags are limited to 256", + variants.len() + ); + writeln!(out, "/** Enum type `{name}` from the module schema. */"); + writeln!(out, "enum class {name} {{"); + out.indent(1); + for (i, variant) in variants.iter().enumerate() { + let variant_name = variant.deref().to_case(Case::Pascal); + let comma = if i + 1 < variants.len() { "," } else { ";" }; + writeln!(out, "{variant_name}{comma}"); + } + writeln!(out); + writeln!(out, "fun encode(writer: BsatnWriter) {{"); + out.indent(1); + writeln!(out, "writer.writeSumTag(ordinal.toUByte())"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + writeln!(out, "companion object {{"); + out.indent(1); + writeln!(out, "fun decode(reader: BsatnReader): {name} {{"); + out.indent(1); + writeln!(out, "val tag = reader.readSumTag().toInt()"); + writeln!( + out, + "return entries.getOrElse(tag) {{ error(\"Unknown {name} tag: $tag\") }}" + ); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); +} + +// --- RemoteTables.kt --- + +fn generate_remote_tables_file(module: &ModuleDef, options: &CodegenOptions) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + writeln!(out, "import {SDK_PKG}.ClientCache"); + writeln!(out, "import {SDK_PKG}.DbConnection"); + writeln!(out, "import {SDK_PKG}.InternalSpacetimeApi"); + writeln!(out, "import {SDK_PKG}.ModuleTables"); + writeln!(out); + + writeln!(out, "/** Generated table accessors for all 
tables in this module. */"); + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "class RemoteTables internal constructor("); + out.indent(1); + writeln!(out, "private val conn: DbConnection,"); + writeln!(out, "private val clientCache: ClientCache,"); + out.dedent(1); + writeln!(out, ") : ModuleTables {{"); + out.indent(1); + + for (_, accessor_name, product_type_ref) in iter_table_names_and_types(module, options.visibility) { + let table_name_pascal = accessor_name.deref().to_case(Case::Pascal); + let table_name_camel = kotlin_ident(accessor_name.deref().to_case(Case::Camel)); + let type_name = type_ref_name(module, product_type_ref); + + writeln!(out, "val {table_name_camel}: {table_name_pascal}TableHandle by lazy {{"); + out.indent(1); + writeln!(out, "@Suppress(\"UNCHECKED_CAST\")"); + writeln!( + out, + "val cache = clientCache.getOrCreateTable<{type_name}>({table_name_pascal}TableHandle.TABLE_NAME) {{" + ); + out.indent(1); + writeln!(out, "{table_name_pascal}TableHandle.createTableCache()"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out, "{table_name_pascal}TableHandle(conn, cache)"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + } + + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: "RemoteTables.kt".to_string(), + code: output.into_inner(), + } +} + +// --- RemoteReducers.kt --- + +fn generate_remote_reducers_file(module: &ModuleDef, options: &CodegenOptions) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + // Collect all imports needed by reducer params + let mut imports = BTreeSet::new(); + imports.insert(format!("{SDK_PKG}.CallbackList")); + imports.insert(format!("{SDK_PKG}.DbConnection")); + imports.insert(format!("{SDK_PKG}.EventContext")); + imports.insert(format!("{SDK_PKG}.InternalSpacetimeApi")); + imports.insert(format!("{SDK_PKG}.ModuleReducers")); + 
imports.insert(format!("{SDK_PKG}.Status")); + + for reducer in iter_reducers(module, options.visibility) { + for (_, ty) in reducer.params_for_generate.elements.iter() { + collect_type_imports(module, ty, &mut imports); + } + } + + for import in &imports { + writeln!(out, "import {import}"); + } + writeln!(out); + + writeln!(out, "/** Generated reducer call methods and callback registration. */"); + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "class RemoteReducers internal constructor("); + out.indent(1); + writeln!(out, "private val conn: DbConnection,"); + out.dedent(1); + writeln!(out, ") : ModuleReducers {{"); + out.indent(1); + + // --- Invocation methods --- + for reducer in iter_reducers(module, options.visibility) { + if !is_reducer_invokable(reducer) { + continue; + } + + let reducer_name_camel = kotlin_ident(reducer.accessor_name.deref().to_case(Case::Camel)); + let reducer_name_pascal = reducer.accessor_name.deref().to_case(Case::Pascal); + + if reducer.params_for_generate.elements.is_empty() { + writeln!( + out, + "fun {reducer_name_camel}(callback: ((EventContext.Reducer) -> Unit)? = null) {{" + ); + out.indent(1); + writeln!( + out, + "conn.callReducer({reducer_name_pascal}Reducer.REDUCER_NAME, ByteArray(0), Unit, callback)" + ); + out.dedent(1); + writeln!(out, "}}"); + } else { + let params: Vec = reducer + .params_for_generate + .elements + .iter() + .map(|(ident, ty)| { + let name = kotlin_ident(ident.deref().to_case(Case::Camel)); + let kotlin_ty = kotlin_type(module, ty); + format!("{name}: {kotlin_ty}") + }) + .collect(); + let params_str = params.join(", "); + writeln!(out, "fun {reducer_name_camel}({params_str}, callback: ((EventContext.Reducer<{reducer_name_pascal}Args>) -> Unit)? 
= null) {{"); + out.indent(1); + // Build the args object + let arg_names: Vec = reducer + .params_for_generate + .elements + .iter() + .map(|(ident, _)| kotlin_ident(ident.deref().to_case(Case::Camel))) + .collect(); + let arg_names_str = arg_names.join(", "); + writeln!(out, "val args = {reducer_name_pascal}Args({arg_names_str})"); + writeln!( + out, + "conn.callReducer({reducer_name_pascal}Reducer.REDUCER_NAME, args.encode(), args, callback)" + ); + out.dedent(1); + writeln!(out, "}}"); + } + writeln!(out); + } + + // --- Per-reducer persistent callbacks --- + for reducer in iter_reducers(module, options.visibility) { + let reducer_name_pascal = reducer.accessor_name.deref().to_case(Case::Pascal); + + // Build the typed callback signature: (EventContext.Reducer, arg1Type, arg2Type, ...) -> Unit + let args_type = if reducer.params_for_generate.elements.is_empty() { + "Unit".to_string() + } else { + format!("{reducer_name_pascal}Args") + }; + let cb_params: Vec = std::iter::once(format!("EventContext.Reducer<{args_type}>")) + .chain( + reducer + .params_for_generate + .elements + .iter() + .map(|(_, ty)| kotlin_type(module, ty)), + ) + .collect(); + let cb_type = format!("({}) -> Unit", cb_params.join(", ")); + + // Callback list + writeln!( + out, + "private val on{reducer_name_pascal}Callbacks = CallbackList<{cb_type}>()" + ); + writeln!(out); + + // on{Reducer} + writeln!(out, "fun on{reducer_name_pascal}(cb: {cb_type}) {{"); + out.indent(1); + writeln!(out, "on{reducer_name_pascal}Callbacks.add(cb)"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // removeOn{Reducer} + writeln!(out, "fun removeOn{reducer_name_pascal}(cb: {cb_type}) {{"); + out.indent(1); + writeln!(out, "on{reducer_name_pascal}Callbacks.remove(cb)"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + } + + // --- Unhandled reducer error fallback --- + writeln!( + out, + "private val onUnhandledReducerErrorCallbacks = CallbackList<(EventContext.Reducer<*>) -> Unit>()" + 
); + writeln!(out); + writeln!( + out, + "/** Register a callback for reducer errors with no specific handler. */" + ); + writeln!( + out, + "fun onUnhandledReducerError(cb: (EventContext.Reducer<*>) -> Unit) {{" + ); + out.indent(1); + writeln!(out, "onUnhandledReducerErrorCallbacks.add(cb)"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + writeln!( + out, + "fun removeOnUnhandledReducerError(cb: (EventContext.Reducer<*>) -> Unit) {{" + ); + out.indent(1); + writeln!(out, "onUnhandledReducerErrorCallbacks.remove(cb)"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // --- handleReducerEvent dispatch --- + writeln!(out, "internal fun handleReducerEvent(ctx: EventContext.Reducer<*>) {{"); + out.indent(1); + writeln!(out, "when (ctx.reducerName) {{"); + out.indent(1); + + for reducer in iter_reducers(module, options.visibility) { + let reducer_name_pascal = reducer.accessor_name.deref().to_case(Case::Pascal); + + writeln!(out, "{reducer_name_pascal}Reducer.REDUCER_NAME -> {{"); + out.indent(1); + writeln!(out, "if (on{reducer_name_pascal}Callbacks.isNotEmpty()) {{"); + out.indent(1); + + if reducer.params_for_generate.elements.is_empty() { + writeln!(out, "@Suppress(\"UNCHECKED_CAST\")"); + writeln!(out, "val typedCtx = ctx as EventContext.Reducer"); + writeln!(out, "on{reducer_name_pascal}Callbacks.forEach {{ it(typedCtx) }}"); + } else { + writeln!(out, "@Suppress(\"UNCHECKED_CAST\")"); + writeln!( + out, + "val typedCtx = ctx as EventContext.Reducer<{reducer_name_pascal}Args>" + ); + // Build the call args from typed args fields + let call_args: Vec = std::iter::once("typedCtx".to_string()) + .chain(reducer.params_for_generate.elements.iter().map(|(ident, _)| { + let field_name = kotlin_ident(ident.deref().to_case(Case::Camel)); + format!("typedCtx.args.{field_name}") + })) + .collect(); + let call_args_str = call_args.join(", "); + writeln!( + out, + "on{reducer_name_pascal}Callbacks.forEach {{ it({call_args_str}) }}" + ); + } + + 
out.dedent(1); + writeln!(out, "}} else if (ctx.status is Status.Failed) {{"); + out.indent(1); + writeln!(out, "onUnhandledReducerErrorCallbacks.forEach {{ it(ctx) }}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + } + + // Fallback for unknown reducer names + writeln!(out, "else -> {{"); + out.indent(1); + writeln!(out, "if (ctx.status is Status.Failed) {{"); + out.indent(1); + writeln!(out, "onUnhandledReducerErrorCallbacks.forEach {{ it(ctx) }}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: "RemoteReducers.kt".to_string(), + code: output.into_inner(), + } +} + +// --- RemoteProcedures.kt --- + +fn generate_remote_procedures_file(module: &ModuleDef, options: &CodegenOptions) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + writeln!(out); + + // Collect all imports needed by procedure params and return types + let mut imports = BTreeSet::new(); + imports.insert(format!("{SDK_PKG}.DbConnection")); + imports.insert(format!("{SDK_PKG}.InternalSpacetimeApi")); + imports.insert(format!("{SDK_PKG}.ModuleProcedures")); + + let has_procedures = iter_procedures(module, options.visibility).next().is_some(); + if has_procedures { + imports.insert(format!("{SDK_PKG}.EventContext")); + imports.insert(format!("{SDK_PKG}.ProcedureError")); + imports.insert(format!("{SDK_PKG}.SdkResult")); + imports.insert(format!("{SDK_PKG}.bsatn.BsatnWriter")); + imports.insert(format!("{SDK_PKG}.bsatn.BsatnReader")); + imports.insert(format!("{SDK_PKG}.protocol.ServerMessage")); + imports.insert(format!("{SDK_PKG}.protocol.ProcedureStatus")); + } + + for procedure in iter_procedures(module, options.visibility) { + for (_, ty) in procedure.params_for_generate.elements.iter() { + 
collect_type_imports(module, ty, &mut imports); + } + collect_type_imports(module, &procedure.return_type_for_generate, &mut imports); + } + + for import in &imports { + writeln!(out, "import {import}"); + } + writeln!(out); + + writeln!( + out, + "/** Generated procedure call methods and callback registration. */" + ); + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "class RemoteProcedures internal constructor("); + out.indent(1); + writeln!(out, "private val conn: DbConnection,"); + out.dedent(1); + writeln!(out, ") : ModuleProcedures {{"); + out.indent(1); + + for procedure in iter_procedures(module, options.visibility) { + let procedure_name_camel = kotlin_ident(procedure.accessor_name.deref().to_case(Case::Camel)); + let procedure_name_pascal = procedure.accessor_name.deref().to_case(Case::Pascal); + let return_ty = &procedure.return_type_for_generate; + let return_ty_str = kotlin_type(module, return_ty); + let is_unit_return = matches!(return_ty, AlgebraicTypeUse::Unit); + + // Build parameter list + let params: Vec = procedure + .params_for_generate + .elements + .iter() + .map(|(ident, ty)| { + let name = kotlin_ident(ident.deref().to_case(Case::Camel)); + let kotlin_ty = kotlin_type(module, ty); + format!("{name}: {kotlin_ty}") + }) + .collect(); + + // Callback type uses SdkResult to surface both success and ProcedureError + let callback_type = if is_unit_return { + "((EventContext.Procedure, SdkResult) -> Unit)?".to_string() + } else { + format!("((EventContext.Procedure, SdkResult<{return_ty_str}, ProcedureError>) -> Unit)?") + }; + + if params.is_empty() { + writeln!(out, "fun {procedure_name_camel}(callback: {callback_type} = null) {{"); + } else { + let params_str = params.join(", "); + writeln!( + out, + "fun {procedure_name_camel}({params_str}, callback: {callback_type} = null) {{" + ); + } + out.indent(1); + + let args_expr = if procedure.params_for_generate.elements.is_empty() { + "ByteArray(0)".to_string() + } else { + 
let arg_names: Vec = procedure + .params_for_generate + .elements + .iter() + .map(|(ident, _)| kotlin_ident(ident.deref().to_case(Case::Camel))) + .collect(); + let arg_names_str = arg_names.join(", "); + writeln!(out, "val args = {procedure_name_pascal}Args({arg_names_str})"); + "args.encode()".to_string() + }; + + // Generate wrapper callback that decodes the return value into a Result + writeln!(out, "val wrappedCallback = callback?.let {{ userCb ->"); + out.indent(1); + writeln!( + out, + "{{ ctx: EventContext.Procedure, msg: ServerMessage.ProcedureResultMsg ->" + ); + out.indent(1); + writeln!(out, "when (val status = msg.status) {{"); + out.indent(1); + writeln!(out, "is ProcedureStatus.Returned -> {{"); + out.indent(1); + if is_unit_return { + writeln!(out, "userCb(ctx, SdkResult.Success(Unit))"); + } else if is_simple_decode(return_ty) { + writeln!(out, "val reader = BsatnReader(status.value)"); + let decode_expr = write_decode_expr(module, return_ty); + writeln!(out, "userCb(ctx, SdkResult.Success({decode_expr}))"); + } else { + writeln!(out, "val reader = BsatnReader(status.value)"); + write_decode_field(module, out, "__retVal", return_ty); + writeln!(out, "userCb(ctx, SdkResult.Success(__retVal))"); + } + out.dedent(1); + writeln!(out, "}}"); + writeln!(out, "is ProcedureStatus.InternalError -> {{"); + out.indent(1); + writeln!( + out, + "userCb(ctx, SdkResult.Failure(ProcedureError.InternalError(status.message)))" + ); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + out.dedent(1); + writeln!(out, "}}"); + + writeln!( + out, + "conn.callProcedure({procedure_name_pascal}Procedure.PROCEDURE_NAME, {args_expr}, wrappedCallback)" + ); + + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + } + + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: "RemoteProcedures.kt".to_string(), + code: output.into_inner(), + } +} + +// --- Module.kt --- + +fn 
generate_module_file(module: &ModuleDef, options: &CodegenOptions) -> OutputFile { + let mut output = CodeIndenter::new(String::new(), INDENT); + let out = &mut output; + + print_file_header(out); + print_auto_generated_version_comment(out); + writeln!(out); + + writeln!(out, "import {SDK_PKG}.ClientCache"); + writeln!(out, "import {SDK_PKG}.DbConnection"); + writeln!(out, "import {SDK_PKG}.DbConnectionView"); + writeln!(out, "import {SDK_PKG}.EventContext"); + writeln!(out, "import {SDK_PKG}.InternalSpacetimeApi"); + writeln!(out, "import {SDK_PKG}.ModuleAccessors"); + writeln!(out, "import {SDK_PKG}.ModuleDescriptor"); + writeln!(out, "import {SDK_PKG}.Query"); + writeln!(out, "import {SDK_PKG}.SubscriptionBuilder"); + writeln!(out, "import {SDK_PKG}.Table"); + writeln!(out); + + // RemoteModule object with version info and table/reducer/procedure names + writeln!(out, "/**"); + writeln!(out, " * Module metadata generated by the SpacetimeDB CLI."); + writeln!( + out, + " * Contains version info and the names of all tables, reducers, and procedures." + ); + writeln!(out, " */"); + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!(out, "object RemoteModule : ModuleDescriptor {{"); + out.indent(1); + + writeln!( + out, + "override val cliVersion: String = \"{}\"", + spacetimedb_lib_version() + ); + writeln!(out); + + // Table and view names list + writeln!(out, "val tableNames: List = listOf("); + out.indent(1); + for (name, _, _) in iter_table_names_and_types(module, options.visibility) { + writeln!(out, "\"{}\",", name.deref()); + } + out.dedent(1); + writeln!(out, ")"); + writeln!(out); + + // Subscribable (persistent) table/view names — excludes event tables + writeln!(out, "override val subscribableTableNames: List = listOf("); + out.indent(1); + for (name, _, _) in iter_table_names_and_types(module, options.visibility) { + // Event tables are not subscribable; views are never event tables. 
+ let is_event = module.tables().any(|t| t.name == *name && t.is_event); + if !is_event { + writeln!(out, "\"{}\",", name.deref()); + } + } + out.dedent(1); + writeln!(out, ")"); + writeln!(out); + + // Reducer names list + writeln!(out, "val reducerNames: List = listOf("); + out.indent(1); + for reducer in iter_reducers(module, options.visibility) { + if !is_reducer_invokable(reducer) { + continue; + } + writeln!(out, "\"{}\",", reducer.name.deref()); + } + out.dedent(1); + writeln!(out, ")"); + writeln!(out); + + // Procedure names list + writeln!(out, "val procedureNames: List = listOf("); + out.indent(1); + for procedure in iter_procedures(module, options.visibility) { + writeln!(out, "\"{}\",", procedure.name.deref()); + } + out.dedent(1); + writeln!(out, ")"); + + writeln!(out); + + // registerTables() — ModuleDescriptor implementation + writeln!(out, "override fun registerTables(cache: ClientCache) {{"); + out.indent(1); + for (_, accessor_name, _) in iter_table_names_and_types(module, options.visibility) { + let table_name_pascal = accessor_name.deref().to_case(Case::Pascal); + writeln!( + out, + "cache.register({table_name_pascal}TableHandle.TABLE_NAME, {table_name_pascal}TableHandle.createTableCache())" + ); + } + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // createAccessors() — ModuleDescriptor implementation + writeln!( + out, + "override fun createAccessors(conn: DbConnection): ModuleAccessors {{" + ); + out.indent(1); + writeln!(out, "return ModuleAccessors("); + out.indent(1); + writeln!(out, "tables = RemoteTables(conn, conn.clientCache),"); + writeln!(out, "reducers = RemoteReducers(conn),"); + writeln!(out, "procedures = RemoteProcedures(conn),"); + out.dedent(1); + writeln!(out, ")"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // handleReducerEvent() — ModuleDescriptor implementation + writeln!( + out, + "override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {{" + ); + out.indent(1); + 
writeln!(out, "conn.reducers.handleReducerEvent(ctx)"); + out.dedent(1); + writeln!(out, "}}"); + + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // Extension properties on DbConnection + writeln!(out, "/**"); + writeln!(out, " * Typed table accessors for this module's tables."); + writeln!(out, " */"); + writeln!(out, "val DbConnection.db: RemoteTables"); + out.indent(1); + writeln!(out, "get() = moduleTables as RemoteTables"); + out.dedent(1); + writeln!(out); + + writeln!(out, "/**"); + writeln!(out, " * Typed reducer call functions for this module's reducers."); + writeln!(out, " */"); + writeln!(out, "val DbConnection.reducers: RemoteReducers"); + out.indent(1); + writeln!(out, "get() = moduleReducers as RemoteReducers"); + out.dedent(1); + writeln!(out); + + writeln!(out, "/**"); + writeln!(out, " * Typed procedure call functions for this module's procedures."); + writeln!(out, " */"); + writeln!(out, "val DbConnection.procedures: RemoteProcedures"); + out.indent(1); + writeln!(out, "get() = moduleProcedures as RemoteProcedures"); + out.dedent(1); + writeln!(out); + + // Extension properties on DbConnectionView (exposed via EventContext.connection) + writeln!(out, "/**"); + writeln!(out, " * Typed table accessors for this module's tables."); + writeln!(out, " */"); + writeln!(out, "val DbConnectionView.db: RemoteTables"); + out.indent(1); + writeln!(out, "get() = moduleTables as RemoteTables"); + out.dedent(1); + writeln!(out); + + writeln!(out, "/**"); + writeln!(out, " * Typed reducer call functions for this module's reducers."); + writeln!(out, " */"); + writeln!(out, "val DbConnectionView.reducers: RemoteReducers"); + out.indent(1); + writeln!(out, "get() = moduleReducers as RemoteReducers"); + out.dedent(1); + writeln!(out); + + writeln!(out, "/**"); + writeln!(out, " * Typed procedure call functions for this module's procedures."); + writeln!(out, " */"); + writeln!(out, "val DbConnectionView.procedures: RemoteProcedures"); + out.indent(1); 
+ writeln!(out, "get() = moduleProcedures as RemoteProcedures"); + out.dedent(1); + writeln!(out); + + // Extension properties on EventContext for typed access in callbacks + writeln!(out, "/**"); + writeln!(out, " * Typed table accessors available directly on event context."); + writeln!(out, " */"); + writeln!(out, "val EventContext.db: RemoteTables"); + out.indent(1); + writeln!(out, "get() = connection.db"); + out.dedent(1); + writeln!(out); + + writeln!(out, "/**"); + writeln!( + out, + " * Typed reducer call functions available directly on event context." + ); + writeln!(out, " */"); + writeln!(out, "val EventContext.reducers: RemoteReducers"); + out.indent(1); + writeln!(out, "get() = connection.reducers"); + out.dedent(1); + writeln!(out); + + writeln!(out, "/**"); + writeln!( + out, + " * Typed procedure call functions available directly on event context." + ); + writeln!(out, " */"); + writeln!(out, "val EventContext.procedures: RemoteProcedures"); + out.indent(1); + writeln!(out, "get() = connection.procedures"); + out.dedent(1); + writeln!(out); + + // Builder extension for zero-config setup + writeln!(out, "/**"); + writeln!(out, " * Registers this module's tables with the connection builder."); + writeln!( + out, + " * Call this on the builder to enable typed [db], [reducers], and [procedures] accessors." 
+ ); + writeln!(out, " *"); + writeln!(out, " * Example:"); + writeln!(out, " * ```kotlin"); + writeln!(out, " * val conn = DbConnection.Builder()"); + writeln!(out, " * .withUri(\"ws://localhost:3000\")"); + writeln!(out, " * .withDatabaseName(\"my_module\")"); + writeln!(out, " * .withModuleBindings()"); + writeln!(out, " * .build()"); + writeln!(out, " * ```"); + writeln!(out, " */"); + writeln!(out, "@OptIn(InternalSpacetimeApi::class)"); + writeln!( + out, + "fun DbConnection.Builder.withModuleBindings(): DbConnection.Builder {{" + ); + out.indent(1); + writeln!(out, "return withModule(RemoteModule)"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // QueryBuilder — typed per-table query builder + writeln!(out, "/**"); + writeln!(out, " * Type-safe query builder for this module's tables."); + writeln!(out, " * Supports WHERE predicates and semi-joins."); + writeln!(out, " */"); + writeln!(out, "class QueryBuilder {{"); + out.indent(1); + for (name, accessor_name, product_type_ref) in iter_table_names_and_types(module, options.visibility) { + let table_name = name.deref(); + let type_name = type_ref_name(module, product_type_ref); + let table_name_pascal = accessor_name.deref().to_case(Case::Pascal); + let method_name = kotlin_ident(accessor_name.deref().to_case(Case::Camel)); + + // Check if this table has indexed columns (views have none) + let has_ix = module + .tables() + .find(|t| t.name == *name) + .is_some_and(|t| iter_indexes(t).any(|idx| matches!(&idx.algorithm, IndexAlgorithm::BTree(_)))); + + if has_ix { + writeln!( + out, + "fun {method_name}(): Table<{type_name}, {table_name_pascal}Cols, {table_name_pascal}IxCols> = Table(\"{table_name}\", {table_name_pascal}Cols(\"{table_name}\"), {table_name_pascal}IxCols(\"{table_name}\"))" + ); + } else { + writeln!( + out, + "fun {method_name}(): Table<{type_name}, {table_name_pascal}Cols, {table_name_pascal}IxCols> = Table(\"{table_name}\", {table_name_pascal}Cols(\"{table_name}\"), 
{table_name_pascal}IxCols())" + ); + } + } + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // Typed addQuery extension on SubscriptionBuilder + writeln!(out, "/**"); + writeln!(out, " * Add a type-safe table query to this subscription."); + writeln!(out, " *"); + writeln!(out, " * Example:"); + writeln!(out, " * ```kotlin"); + writeln!(out, " * conn.subscriptionBuilder()"); + writeln!(out, " * .addQuery {{ qb -> qb.player() }}"); + writeln!( + out, + " * .addQuery {{ qb -> qb.player().where {{ c -> c.health.gt(50) }} }}" + ); + writeln!(out, " * .subscribe()"); + writeln!(out, " * ```"); + writeln!(out, " */"); + writeln!( + out, + "fun SubscriptionBuilder.addQuery(build: (QueryBuilder) -> Query<*>): SubscriptionBuilder {{" + ); + out.indent(1); + writeln!(out, "return addQuery(build(QueryBuilder()).toSql())"); + out.dedent(1); + writeln!(out, "}}"); + writeln!(out); + + // Generated subscribeToAllTables with baked-in queries via QueryBuilder + writeln!(out, "/**"); + writeln!(out, " * Subscribe to all persistent tables in this module."); + writeln!( + out, + " * Event tables are excluded because the server does not support subscribing to them." + ); + writeln!(out, " */"); + writeln!( + out, + "fun SubscriptionBuilder.subscribeToAllTables(): {SDK_PKG}.SubscriptionHandle {{" + ); + out.indent(1); + writeln!(out, "val qb = QueryBuilder()"); + for (name, accessor_name, _) in iter_table_names_and_types(module, options.visibility) { + // Event tables are not subscribable; views are never event tables. 
+ let is_event = module.tables().any(|t| t.name == *name && t.is_event); + if !is_event { + let method_name = kotlin_ident(accessor_name.deref().to_case(Case::Camel)); + writeln!(out, "addQuery(qb.{method_name}().toSql())"); + } + } + writeln!(out, "return subscribe()"); + out.dedent(1); + writeln!(out, "}}"); + + OutputFile { + filename: "Module.kt".to_string(), + code: output.into_inner(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn kotlin_ident_escapes_hard_keywords() { + for &kw in KOTLIN_HARD_KEYWORDS { + assert_eq!( + kotlin_ident(kw.to_string()), + format!("`{kw}`"), + "Expected keyword '{kw}' to be backtick-escaped" + ); + } + } + + #[test] + fn kotlin_ident_passes_through_non_keywords() { + let non_keywords = ["name", "age", "id", "foo", "bar", "myField", "data", "value"]; + for &name in &non_keywords { + assert_eq!( + kotlin_ident(name.to_string()), + name, + "Non-keyword '{name}' should not be escaped" + ); + } + } + + #[test] + fn kotlin_ident_is_case_sensitive() { + // PascalCase versions of keywords are NOT keywords + assert_eq!(kotlin_ident("Object".to_string()), "Object"); + assert_eq!(kotlin_ident("Class".to_string()), "Class"); + assert_eq!(kotlin_ident("When".to_string()), "When"); + assert_eq!(kotlin_ident("Val".to_string()), "Val"); + // But lowercase versions are + assert_eq!(kotlin_ident("object".to_string()), "`object`"); + assert_eq!(kotlin_ident("class".to_string()), "`class`"); + assert_eq!(kotlin_ident("when".to_string()), "`when`"); + assert_eq!(kotlin_ident("val".to_string()), "`val`"); + } +} diff --git a/crates/codegen/src/lib.rs b/crates/codegen/src/lib.rs index 28d4fb8a5a4..ed84bca0e7a 100644 --- a/crates/codegen/src/lib.rs +++ b/crates/codegen/src/lib.rs @@ -3,12 +3,14 @@ use spacetimedb_schema::schema::{Schema, TableSchema}; mod code_indenter; pub mod cpp; pub mod csharp; +pub mod kotlin; pub mod rust; pub mod typescript; pub mod unrealcpp; mod util; pub use self::csharp::Csharp; +pub use 
self::kotlin::Kotlin; pub use self::rust::Rust; pub use self::typescript::TypeScript; pub use self::unrealcpp::UnrealCpp; diff --git a/crates/codegen/tests/codegen.rs b/crates/codegen/tests/codegen.rs index 06dc3ebe8fc..5ff30b496be 100644 --- a/crates/codegen/tests/codegen.rs +++ b/crates/codegen/tests/codegen.rs @@ -1,4 +1,4 @@ -use spacetimedb_codegen::{generate, CodegenOptions, Csharp, Rust, TypeScript}; +use spacetimedb_codegen::{generate, kotlin::Kotlin, CodegenOptions, Csharp, Rust, TypeScript}; use spacetimedb_data_structures::map::HashMap; use spacetimedb_schema::def::ModuleDef; use spacetimedb_testing::modules::{CompilationMode, CompiledModule}; @@ -36,6 +36,7 @@ macro_rules! declare_tests { declare_tests! { test_codegen_csharp => Csharp { namespace: "SpacetimeDB" }, - test_codegen_typescript => TypeScript, + test_codegen_kotlin => Kotlin, test_codegen_rust => Rust, + test_codegen_typescript => TypeScript, } diff --git a/crates/codegen/tests/snapshots/codegen__codegen_kotlin.snap b/crates/codegen/tests/snapshots/codegen__codegen_kotlin.snap new file mode 100644 index 00000000000..e6a21b3ecaa --- /dev/null +++ b/crates/codegen/tests/snapshots/codegen__codegen_kotlin.snap @@ -0,0 +1,1888 @@ +--- +source: crates/codegen/tests/codegen.rs +expression: outfiles +--- +"AddPlayerReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `add_player` reducer. */ +data class AddPlayerArgs( + val name: String +) { + /** Encodes these arguments to BSATN. 
*/ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeString(name) + return writer.toByteArray() + } + + companion object { + /** Decodes [AddPlayerArgs] from BSATN. */ + fun decode(reader: BsatnReader): AddPlayerArgs { + val name = reader.readString() + return AddPlayerArgs(name) + } + } +} + +/** Constants for the `add_player` reducer. */ +object AddPlayerReducer { + const val REDUCER_NAME = "add_player" +} +''' +"AddPrivateReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `add_private` reducer. */ +data class AddPrivateArgs( + val name: String +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeString(name) + return writer.toByteArray() + } + + companion object { + /** Decodes [AddPrivateArgs] from BSATN. */ + fun decode(reader: BsatnReader): AddPrivateArgs { + val name = reader.readString() + return AddPrivateArgs(name) + } + } +} + +/** Constants for the `add_private` reducer. */ +object AddPrivateReducer { + const val REDUCER_NAME = "add_private" +} +''' +"AddReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `add` reducer. 
*/ +data class AddArgs( + val name: String, + val age: UByte +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeString(name) + writer.writeU8(age) + return writer.toByteArray() + } + + companion object { + /** Decodes [AddArgs] from BSATN. */ + fun decode(reader: BsatnReader): AddArgs { + val name = reader.readString() + val age = reader.readU8() + return AddArgs(name, age) + } + } +} + +/** Constants for the `add` reducer. */ +object AddReducer { + const val REDUCER_NAME = "add" +} +''' +"AssertCallerIdentityIsModuleIdentityReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Constants for the `assert_caller_identity_is_module_identity` reducer. */ +object AssertCallerIdentityIsModuleIdentityReducer { + const val REDUCER_NAME = "assert_caller_identity_is_module_identity" +} +''' +"DeletePlayerReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `delete_player` reducer. */ +data class DeletePlayerArgs( + val id: ULong +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeU64(id) + return writer.toByteArray() + } + + companion object { + /** Decodes [DeletePlayerArgs] from BSATN. 
*/ + fun decode(reader: BsatnReader): DeletePlayerArgs { + val id = reader.readU64() + return DeletePlayerArgs(id) + } + } +} + +/** Constants for the `delete_player` reducer. */ +object DeletePlayerReducer { + const val REDUCER_NAME = "delete_player" +} +''' +"DeletePlayersByNameReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `delete_players_by_name` reducer. */ +data class DeletePlayersByNameArgs( + val name: String +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeString(name) + return writer.toByteArray() + } + + companion object { + /** Decodes [DeletePlayersByNameArgs] from BSATN. */ + fun decode(reader: BsatnReader): DeletePlayersByNameArgs { + val name = reader.readString() + return DeletePlayersByNameArgs(name) + } + } +} + +/** Constants for the `delete_players_by_name` reducer. */ +object DeletePlayersByNameReducer { + const val REDUCER_NAME = "delete_players_by_name" +} +''' +"GetMySchemaViaHttpProcedure.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. 
+ +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +object GetMySchemaViaHttpProcedure { + const val PROCEDURE_NAME = "get_my_schema_via_http" + // Returns: String +} +''' +"ListOverAgeReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `list_over_age` reducer. */ +data class ListOverAgeArgs( + val age: UByte +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeU8(age) + return writer.toByteArray() + } + + companion object { + /** Decodes [ListOverAgeArgs] from BSATN. */ + fun decode(reader: BsatnReader): ListOverAgeArgs { + val age = reader.readU8() + return ListOverAgeArgs(age) + } + } +} + +/** Constants for the `list_over_age` reducer. */ +object ListOverAgeReducer { + const val REDUCER_NAME = "list_over_age" +} +''' +"LogModuleIdentityReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Constants for the `log_module_identity` reducer. 
*/ +object LogModuleIdentityReducer { + const val REDUCER_NAME = "log_module_identity" +} +''' +"LoggedOutPlayerTableHandle.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.IxCol +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTableWithPrimaryKey +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UniqueIndex +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity + +/** Client-side handle for the `logged_out_player` table. 
*/ +@OptIn(InternalSpacetimeApi::class) +class LoggedOutPlayerTableHandle internal constructor( + private val conn: DbConnection, + private val tableCache: TableCache, +) : RemotePersistentTableWithPrimaryKey { + companion object { + const val TABLE_NAME = "logged_out_player" + + const val FIELD_IDENTITY = "identity" + const val FIELD_PLAYER_ID = "player_id" + const val FIELD_NAME = "name" + + fun createTableCache(): TableCache { + return TableCache.withPrimaryKey({ reader -> Player.decode(reader) }) { row -> row.identity } + } + } + + override fun count(): Int = tableCache.count() + override fun all(): List = tableCache.all() + override fun iter(): Sequence = tableCache.iter() + + override fun onInsert(cb: (EventContext, Player) -> Unit) { tableCache.onInsert(cb) } + override fun removeOnInsert(cb: (EventContext, Player) -> Unit) { tableCache.removeOnInsert(cb) } + override fun onDelete(cb: (EventContext, Player) -> Unit) { tableCache.onDelete(cb) } + override fun onUpdate(cb: (EventContext, Player, Player) -> Unit) { tableCache.onUpdate(cb) } + override fun onBeforeDelete(cb: (EventContext, Player) -> Unit) { tableCache.onBeforeDelete(cb) } + + override fun removeOnDelete(cb: (EventContext, Player) -> Unit) { tableCache.removeOnDelete(cb) } + override fun removeOnUpdate(cb: (EventContext, Player, Player) -> Unit) { tableCache.removeOnUpdate(cb) } + override fun removeOnBeforeDelete(cb: (EventContext, Player) -> Unit) { tableCache.removeOnBeforeDelete(cb) } + + val identity = UniqueIndex(tableCache) { it.identity } + + val name = UniqueIndex(tableCache) { it.name } + + val playerId = UniqueIndex(tableCache) { it.playerId } + +} + +@OptIn(InternalSpacetimeApi::class) +class LoggedOutPlayerCols(tableName: String) { + val identity = Col(tableName, "identity") + val playerId = Col(tableName, "player_id") + val name = Col(tableName, "name") +} + +@OptIn(InternalSpacetimeApi::class) +class LoggedOutPlayerIxCols(tableName: String) { + val identity = IxCol(tableName, 
"identity") + val playerId = IxCol(tableName, "player_id") + val name = IxCol(tableName, "name") +} +''' +"Module.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings +VERSION_COMMENT + + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ClientCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnectionView +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleAccessors +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleDescriptor +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Query +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionBuilder +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Table + +/** + * Module metadata generated by the SpacetimeDB CLI. + * Contains version info and the names of all tables, reducers, and procedures. 
+ */ +@OptIn(InternalSpacetimeApi::class) +object RemoteModule : ModuleDescriptor { + override val cliVersion: String = "2.1.0" + + val tableNames: List = listOf( + "logged_out_player", + "my_player", + "person", + "player", + "test_d", + "test_f", + ) + + override val subscribableTableNames: List = listOf( + "logged_out_player", + "my_player", + "person", + "player", + "test_d", + "test_f", + ) + + val reducerNames: List = listOf( + "add", + "add_player", + "add_private", + "assert_caller_identity_is_module_identity", + "delete_player", + "delete_players_by_name", + "list_over_age", + "log_module_identity", + "query_private", + "say_hello", + "test", + "test_btree_index_args", + ) + + val procedureNames: List = listOf( + "get_my_schema_via_http", + "return_value", + "sleep_one_second", + "with_tx", + ) + + override fun registerTables(cache: ClientCache) { + cache.register(LoggedOutPlayerTableHandle.TABLE_NAME, LoggedOutPlayerTableHandle.createTableCache()) + cache.register(MyPlayerTableHandle.TABLE_NAME, MyPlayerTableHandle.createTableCache()) + cache.register(PersonTableHandle.TABLE_NAME, PersonTableHandle.createTableCache()) + cache.register(PlayerTableHandle.TABLE_NAME, PlayerTableHandle.createTableCache()) + cache.register(TestDTableHandle.TABLE_NAME, TestDTableHandle.createTableCache()) + cache.register(TestFTableHandle.TABLE_NAME, TestFTableHandle.createTableCache()) + } + + override fun createAccessors(conn: DbConnection): ModuleAccessors { + return ModuleAccessors( + tables = RemoteTables(conn, conn.clientCache), + reducers = RemoteReducers(conn), + procedures = RemoteProcedures(conn), + ) + } + + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) { + conn.reducers.handleReducerEvent(ctx) + } +} + +/** + * Typed table accessors for this module's tables. + */ +val DbConnection.db: RemoteTables + get() = moduleTables as RemoteTables + +/** + * Typed reducer call functions for this module's reducers. 
+ */ +val DbConnection.reducers: RemoteReducers + get() = moduleReducers as RemoteReducers + +/** + * Typed procedure call functions for this module's procedures. + */ +val DbConnection.procedures: RemoteProcedures + get() = moduleProcedures as RemoteProcedures + +/** + * Typed table accessors for this module's tables. + */ +val DbConnectionView.db: RemoteTables + get() = moduleTables as RemoteTables + +/** + * Typed reducer call functions for this module's reducers. + */ +val DbConnectionView.reducers: RemoteReducers + get() = moduleReducers as RemoteReducers + +/** + * Typed procedure call functions for this module's procedures. + */ +val DbConnectionView.procedures: RemoteProcedures + get() = moduleProcedures as RemoteProcedures + +/** + * Typed table accessors available directly on event context. + */ +val EventContext.db: RemoteTables + get() = connection.db + +/** + * Typed reducer call functions available directly on event context. + */ +val EventContext.reducers: RemoteReducers + get() = connection.reducers + +/** + * Typed procedure call functions available directly on event context. + */ +val EventContext.procedures: RemoteProcedures + get() = connection.procedures + +/** + * Registers this module's tables with the connection builder. + * Call this on the builder to enable typed [db], [reducers], and [procedures] accessors. + * + * Example: + * ```kotlin + * val conn = DbConnection.Builder() + * .withUri("ws://localhost:3000") + * .withDatabaseName("my_module") + * .withModuleBindings() + * .build() + * ``` + */ +@OptIn(InternalSpacetimeApi::class) +fun DbConnection.Builder.withModuleBindings(): DbConnection.Builder { + return withModule(RemoteModule) +} + +/** + * Type-safe query builder for this module's tables. + * Supports WHERE predicates and semi-joins. 
+ */ +class QueryBuilder { + fun loggedOutPlayer(): Table = Table("logged_out_player", LoggedOutPlayerCols("logged_out_player"), LoggedOutPlayerIxCols("logged_out_player")) + fun myPlayer(): Table = Table("my_player", MyPlayerCols("my_player"), MyPlayerIxCols()) + fun person(): Table = Table("person", PersonCols("person"), PersonIxCols("person")) + fun player(): Table = Table("player", PlayerCols("player"), PlayerIxCols("player")) + fun testD(): Table = Table("test_d", TestDCols("test_d"), TestDIxCols()) + fun testF(): Table = Table("test_f", TestFCols("test_f"), TestFIxCols()) +} + +/** + * Add a type-safe table query to this subscription. + * + * Example: + * ```kotlin + * conn.subscriptionBuilder() + * .addQuery { qb -> qb.player() } + * .addQuery { qb -> qb.player().where { c -> c.health.gt(50) } } + * .subscribe() + * ``` + */ +fun SubscriptionBuilder.addQuery(build: (QueryBuilder) -> Query<*>): SubscriptionBuilder { + return addQuery(build(QueryBuilder()).toSql()) +} + +/** + * Subscribe to all persistent tables in this module. + * Event tables are excluded because the server does not support subscribing to them. + */ +fun SubscriptionBuilder.subscribeToAllTables(): com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionHandle { + val qb = QueryBuilder() + addQuery(qb.loggedOutPlayer().toSql()) + addQuery(qb.myPlayer().toSql()) + addQuery(qb.person().toSql()) + addQuery(qb.player().toSql()) + addQuery(qb.testD().toSql()) + addQuery(qb.testF().toSql()) + return subscribe() +} +''' +"MyPlayerTableHandle.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. 
+ +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTable +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity + +/** Client-side handle for the `my_player` table. */ +@OptIn(InternalSpacetimeApi::class) +class MyPlayerTableHandle internal constructor( + private val conn: DbConnection, + private val tableCache: TableCache, +) : RemotePersistentTable { + companion object { + const val TABLE_NAME = "my_player" + + const val FIELD_IDENTITY = "identity" + const val FIELD_PLAYER_ID = "player_id" + const val FIELD_NAME = "name" + + fun createTableCache(): TableCache { + return TableCache.withContentKey { reader -> Player.decode(reader) } + } + } + + override fun count(): Int = tableCache.count() + override fun all(): List = tableCache.all() + override fun iter(): Sequence = tableCache.iter() + + override fun onInsert(cb: (EventContext, Player) -> Unit) { tableCache.onInsert(cb) } + override fun removeOnInsert(cb: (EventContext, Player) -> Unit) { tableCache.removeOnInsert(cb) } + override fun onDelete(cb: (EventContext, Player) -> Unit) { tableCache.onDelete(cb) } + override fun onBeforeDelete(cb: (EventContext, Player) -> Unit) { tableCache.onBeforeDelete(cb) } + + override fun removeOnDelete(cb: (EventContext, Player) -> Unit) { tableCache.removeOnDelete(cb) } + override fun removeOnBeforeDelete(cb: (EventContext, Player) -> Unit) { tableCache.removeOnBeforeDelete(cb) } + +} + 
+@OptIn(InternalSpacetimeApi::class) +class MyPlayerCols(tableName: String) { + val identity = Col(tableName, "identity") + val playerId = Col(tableName, "player_id") + val name = Col(tableName, "name") +} + +class MyPlayerIxCols +''' +"PersonTableHandle.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BTreeIndex +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.IxCol +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTableWithPrimaryKey +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UniqueIndex +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult + +/** Client-side handle for the `person` table. 
*/ +@OptIn(InternalSpacetimeApi::class) +class PersonTableHandle internal constructor( + private val conn: DbConnection, + private val tableCache: TableCache, +) : RemotePersistentTableWithPrimaryKey { + companion object { + const val TABLE_NAME = "person" + + const val FIELD_ID = "id" + const val FIELD_NAME = "name" + const val FIELD_AGE = "age" + + fun createTableCache(): TableCache { + return TableCache.withPrimaryKey({ reader -> Person.decode(reader) }) { row -> row.id } + } + } + + override fun count(): Int = tableCache.count() + override fun all(): List = tableCache.all() + override fun iter(): Sequence = tableCache.iter() + + override fun onInsert(cb: (EventContext, Person) -> Unit) { tableCache.onInsert(cb) } + override fun removeOnInsert(cb: (EventContext, Person) -> Unit) { tableCache.removeOnInsert(cb) } + override fun onDelete(cb: (EventContext, Person) -> Unit) { tableCache.onDelete(cb) } + override fun onUpdate(cb: (EventContext, Person, Person) -> Unit) { tableCache.onUpdate(cb) } + override fun onBeforeDelete(cb: (EventContext, Person) -> Unit) { tableCache.onBeforeDelete(cb) } + + override fun removeOnDelete(cb: (EventContext, Person) -> Unit) { tableCache.removeOnDelete(cb) } + override fun removeOnUpdate(cb: (EventContext, Person, Person) -> Unit) { tableCache.removeOnUpdate(cb) } + override fun removeOnBeforeDelete(cb: (EventContext, Person) -> Unit) { tableCache.removeOnBeforeDelete(cb) } + + val age = BTreeIndex(tableCache) { it.age } + + val id = UniqueIndex(tableCache) { it.id } + +} + +@OptIn(InternalSpacetimeApi::class) +class PersonCols(tableName: String) { + val id = Col(tableName, "id") + val name = Col(tableName, "name") + val age = Col(tableName, "age") +} + +@OptIn(InternalSpacetimeApi::class) +class PersonIxCols(tableName: String) { + val id = IxCol(tableName, "id") + val age = IxCol(tableName, "age") +} +''' +"PlayerTableHandle.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. 
EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.IxCol +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTableWithPrimaryKey +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UniqueIndex +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity + +/** Client-side handle for the `player` table. 
*/ +@OptIn(InternalSpacetimeApi::class) +class PlayerTableHandle internal constructor( + private val conn: DbConnection, + private val tableCache: TableCache, +) : RemotePersistentTableWithPrimaryKey { + companion object { + const val TABLE_NAME = "player" + + const val FIELD_IDENTITY = "identity" + const val FIELD_PLAYER_ID = "player_id" + const val FIELD_NAME = "name" + + fun createTableCache(): TableCache { + return TableCache.withPrimaryKey({ reader -> Player.decode(reader) }) { row -> row.identity } + } + } + + override fun count(): Int = tableCache.count() + override fun all(): List = tableCache.all() + override fun iter(): Sequence = tableCache.iter() + + override fun onInsert(cb: (EventContext, Player) -> Unit) { tableCache.onInsert(cb) } + override fun removeOnInsert(cb: (EventContext, Player) -> Unit) { tableCache.removeOnInsert(cb) } + override fun onDelete(cb: (EventContext, Player) -> Unit) { tableCache.onDelete(cb) } + override fun onUpdate(cb: (EventContext, Player, Player) -> Unit) { tableCache.onUpdate(cb) } + override fun onBeforeDelete(cb: (EventContext, Player) -> Unit) { tableCache.onBeforeDelete(cb) } + + override fun removeOnDelete(cb: (EventContext, Player) -> Unit) { tableCache.removeOnDelete(cb) } + override fun removeOnUpdate(cb: (EventContext, Player, Player) -> Unit) { tableCache.removeOnUpdate(cb) } + override fun removeOnBeforeDelete(cb: (EventContext, Player) -> Unit) { tableCache.removeOnBeforeDelete(cb) } + + val identity = UniqueIndex(tableCache) { it.identity } + + val name = UniqueIndex(tableCache) { it.name } + + val playerId = UniqueIndex(tableCache) { it.playerId } + +} + +@OptIn(InternalSpacetimeApi::class) +class PlayerCols(tableName: String) { + val identity = Col(tableName, "identity") + val playerId = Col(tableName, "player_id") + val name = Col(tableName, "name") +} + +@OptIn(InternalSpacetimeApi::class) +class PlayerIxCols(tableName: String) { + val identity = IxCol(tableName, "identity") + val playerId = 
IxCol(tableName, "player_id") + val name = IxCol(tableName, "name") +} +''' +"QueryPrivateReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Constants for the `query_private` reducer. */ +object QueryPrivateReducer { + const val REDUCER_NAME = "query_private" +} +''' +"RemoteProcedures.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleProcedures +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ProcedureError +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SdkResult +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ProcedureStatus +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage + +/** Generated procedure call methods and callback registration. */ +@OptIn(InternalSpacetimeApi::class) +class RemoteProcedures internal constructor( + private val conn: DbConnection, +) : ModuleProcedures { + fun getMySchemaViaHttp(callback: ((EventContext.Procedure, SdkResult) -> Unit)? 
= null) { + val wrappedCallback = callback?.let { userCb -> + { ctx: EventContext.Procedure, msg: ServerMessage.ProcedureResultMsg -> + when (val status = msg.status) { + is ProcedureStatus.Returned -> { + val reader = BsatnReader(status.value) + userCb(ctx, SdkResult.Success(reader.readString())) + } + is ProcedureStatus.InternalError -> { + userCb(ctx, SdkResult.Failure(ProcedureError.InternalError(status.message))) + } + } + } + } + conn.callProcedure(GetMySchemaViaHttpProcedure.PROCEDURE_NAME, ByteArray(0), wrappedCallback) + } + + fun returnValue(foo: ULong, callback: ((EventContext.Procedure, SdkResult) -> Unit)? = null) { + val args = ReturnValueArgs(foo) + val wrappedCallback = callback?.let { userCb -> + { ctx: EventContext.Procedure, msg: ServerMessage.ProcedureResultMsg -> + when (val status = msg.status) { + is ProcedureStatus.Returned -> { + val reader = BsatnReader(status.value) + userCb(ctx, SdkResult.Success(Baz.decode(reader))) + } + is ProcedureStatus.InternalError -> { + userCb(ctx, SdkResult.Failure(ProcedureError.InternalError(status.message))) + } + } + } + } + conn.callProcedure(ReturnValueProcedure.PROCEDURE_NAME, args.encode(), wrappedCallback) + } + + fun sleepOneSecond(callback: ((EventContext.Procedure, SdkResult) -> Unit)? = null) { + val wrappedCallback = callback?.let { userCb -> + { ctx: EventContext.Procedure, msg: ServerMessage.ProcedureResultMsg -> + when (val status = msg.status) { + is ProcedureStatus.Returned -> { + userCb(ctx, SdkResult.Success(Unit)) + } + is ProcedureStatus.InternalError -> { + userCb(ctx, SdkResult.Failure(ProcedureError.InternalError(status.message))) + } + } + } + } + conn.callProcedure(SleepOneSecondProcedure.PROCEDURE_NAME, ByteArray(0), wrappedCallback) + } + + fun withTx(callback: ((EventContext.Procedure, SdkResult) -> Unit)? 
= null) { + val wrappedCallback = callback?.let { userCb -> + { ctx: EventContext.Procedure, msg: ServerMessage.ProcedureResultMsg -> + when (val status = msg.status) { + is ProcedureStatus.Returned -> { + userCb(ctx, SdkResult.Success(Unit)) + } + is ProcedureStatus.InternalError -> { + userCb(ctx, SdkResult.Failure(ProcedureError.InternalError(status.message))) + } + } + } + } + conn.callProcedure(WithTxProcedure.PROCEDURE_NAME, ByteArray(0), wrappedCallback) + } + +} +''' +"RemoteReducers.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CallbackList +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleReducers +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Status + +/** Generated reducer call methods and callback registration. */ +@OptIn(InternalSpacetimeApi::class) +class RemoteReducers internal constructor( + private val conn: DbConnection, +) : ModuleReducers { + fun add(name: String, age: UByte, callback: ((EventContext.Reducer) -> Unit)? = null) { + val args = AddArgs(name, age) + conn.callReducer(AddReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun addPlayer(name: String, callback: ((EventContext.Reducer) -> Unit)? = null) { + val args = AddPlayerArgs(name) + conn.callReducer(AddPlayerReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun addPrivate(name: String, callback: ((EventContext.Reducer) -> Unit)? 
= null) { + val args = AddPrivateArgs(name) + conn.callReducer(AddPrivateReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun assertCallerIdentityIsModuleIdentity(callback: ((EventContext.Reducer) -> Unit)? = null) { + conn.callReducer(AssertCallerIdentityIsModuleIdentityReducer.REDUCER_NAME, ByteArray(0), Unit, callback) + } + + fun deletePlayer(id: ULong, callback: ((EventContext.Reducer) -> Unit)? = null) { + val args = DeletePlayerArgs(id) + conn.callReducer(DeletePlayerReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun deletePlayersByName(name: String, callback: ((EventContext.Reducer) -> Unit)? = null) { + val args = DeletePlayersByNameArgs(name) + conn.callReducer(DeletePlayersByNameReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun listOverAge(age: UByte, callback: ((EventContext.Reducer) -> Unit)? = null) { + val args = ListOverAgeArgs(age) + conn.callReducer(ListOverAgeReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun logModuleIdentity(callback: ((EventContext.Reducer) -> Unit)? = null) { + conn.callReducer(LogModuleIdentityReducer.REDUCER_NAME, ByteArray(0), Unit, callback) + } + + fun queryPrivate(callback: ((EventContext.Reducer) -> Unit)? = null) { + conn.callReducer(QueryPrivateReducer.REDUCER_NAME, ByteArray(0), Unit, callback) + } + + fun sayHello(callback: ((EventContext.Reducer) -> Unit)? = null) { + conn.callReducer(SayHelloReducer.REDUCER_NAME, ByteArray(0), Unit, callback) + } + + fun test(arg: TestA, arg2: TestB, arg3: NamespaceTestC, arg4: NamespaceTestF, callback: ((EventContext.Reducer) -> Unit)? = null) { + val args = TestArgs(arg, arg2, arg3, arg4) + conn.callReducer(TestReducer.REDUCER_NAME, args.encode(), args, callback) + } + + fun testBtreeIndexArgs(callback: ((EventContext.Reducer) -> Unit)? 
= null) { + conn.callReducer(TestBtreeIndexArgsReducer.REDUCER_NAME, ByteArray(0), Unit, callback) + } + + private val onAddCallbacks = CallbackList<(EventContext.Reducer, String, UByte) -> Unit>() + + fun onAdd(cb: (EventContext.Reducer, String, UByte) -> Unit) { + onAddCallbacks.add(cb) + } + + fun removeOnAdd(cb: (EventContext.Reducer, String, UByte) -> Unit) { + onAddCallbacks.remove(cb) + } + + private val onAddPlayerCallbacks = CallbackList<(EventContext.Reducer, String) -> Unit>() + + fun onAddPlayer(cb: (EventContext.Reducer, String) -> Unit) { + onAddPlayerCallbacks.add(cb) + } + + fun removeOnAddPlayer(cb: (EventContext.Reducer, String) -> Unit) { + onAddPlayerCallbacks.remove(cb) + } + + private val onAddPrivateCallbacks = CallbackList<(EventContext.Reducer, String) -> Unit>() + + fun onAddPrivate(cb: (EventContext.Reducer, String) -> Unit) { + onAddPrivateCallbacks.add(cb) + } + + fun removeOnAddPrivate(cb: (EventContext.Reducer, String) -> Unit) { + onAddPrivateCallbacks.remove(cb) + } + + private val onAssertCallerIdentityIsModuleIdentityCallbacks = CallbackList<(EventContext.Reducer) -> Unit>() + + fun onAssertCallerIdentityIsModuleIdentity(cb: (EventContext.Reducer) -> Unit) { + onAssertCallerIdentityIsModuleIdentityCallbacks.add(cb) + } + + fun removeOnAssertCallerIdentityIsModuleIdentity(cb: (EventContext.Reducer) -> Unit) { + onAssertCallerIdentityIsModuleIdentityCallbacks.remove(cb) + } + + private val onDeletePlayerCallbacks = CallbackList<(EventContext.Reducer, ULong) -> Unit>() + + fun onDeletePlayer(cb: (EventContext.Reducer, ULong) -> Unit) { + onDeletePlayerCallbacks.add(cb) + } + + fun removeOnDeletePlayer(cb: (EventContext.Reducer, ULong) -> Unit) { + onDeletePlayerCallbacks.remove(cb) + } + + private val onDeletePlayersByNameCallbacks = CallbackList<(EventContext.Reducer, String) -> Unit>() + + fun onDeletePlayersByName(cb: (EventContext.Reducer, String) -> Unit) { + onDeletePlayersByNameCallbacks.add(cb) + } + + fun 
removeOnDeletePlayersByName(cb: (EventContext.Reducer, String) -> Unit) { + onDeletePlayersByNameCallbacks.remove(cb) + } + + private val onListOverAgeCallbacks = CallbackList<(EventContext.Reducer, UByte) -> Unit>() + + fun onListOverAge(cb: (EventContext.Reducer, UByte) -> Unit) { + onListOverAgeCallbacks.add(cb) + } + + fun removeOnListOverAge(cb: (EventContext.Reducer, UByte) -> Unit) { + onListOverAgeCallbacks.remove(cb) + } + + private val onLogModuleIdentityCallbacks = CallbackList<(EventContext.Reducer) -> Unit>() + + fun onLogModuleIdentity(cb: (EventContext.Reducer) -> Unit) { + onLogModuleIdentityCallbacks.add(cb) + } + + fun removeOnLogModuleIdentity(cb: (EventContext.Reducer) -> Unit) { + onLogModuleIdentityCallbacks.remove(cb) + } + + private val onQueryPrivateCallbacks = CallbackList<(EventContext.Reducer) -> Unit>() + + fun onQueryPrivate(cb: (EventContext.Reducer) -> Unit) { + onQueryPrivateCallbacks.add(cb) + } + + fun removeOnQueryPrivate(cb: (EventContext.Reducer) -> Unit) { + onQueryPrivateCallbacks.remove(cb) + } + + private val onSayHelloCallbacks = CallbackList<(EventContext.Reducer) -> Unit>() + + fun onSayHello(cb: (EventContext.Reducer) -> Unit) { + onSayHelloCallbacks.add(cb) + } + + fun removeOnSayHello(cb: (EventContext.Reducer) -> Unit) { + onSayHelloCallbacks.remove(cb) + } + + private val onTestCallbacks = CallbackList<(EventContext.Reducer, TestA, TestB, NamespaceTestC, NamespaceTestF) -> Unit>() + + fun onTest(cb: (EventContext.Reducer, TestA, TestB, NamespaceTestC, NamespaceTestF) -> Unit) { + onTestCallbacks.add(cb) + } + + fun removeOnTest(cb: (EventContext.Reducer, TestA, TestB, NamespaceTestC, NamespaceTestF) -> Unit) { + onTestCallbacks.remove(cb) + } + + private val onTestBtreeIndexArgsCallbacks = CallbackList<(EventContext.Reducer) -> Unit>() + + fun onTestBtreeIndexArgs(cb: (EventContext.Reducer) -> Unit) { + onTestBtreeIndexArgsCallbacks.add(cb) + } + + fun removeOnTestBtreeIndexArgs(cb: (EventContext.Reducer) -> Unit) { 
+ onTestBtreeIndexArgsCallbacks.remove(cb) + } + + private val onUnhandledReducerErrorCallbacks = CallbackList<(EventContext.Reducer<*>) -> Unit>() + + /** Register a callback for reducer errors with no specific handler. */ + fun onUnhandledReducerError(cb: (EventContext.Reducer<*>) -> Unit) { + onUnhandledReducerErrorCallbacks.add(cb) + } + + fun removeOnUnhandledReducerError(cb: (EventContext.Reducer<*>) -> Unit) { + onUnhandledReducerErrorCallbacks.remove(cb) + } + + internal fun handleReducerEvent(ctx: EventContext.Reducer<*>) { + when (ctx.reducerName) { + AddReducer.REDUCER_NAME -> { + if (onAddCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onAddCallbacks.forEach { it(typedCtx, typedCtx.args.name, typedCtx.args.age) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + AddPlayerReducer.REDUCER_NAME -> { + if (onAddPlayerCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onAddPlayerCallbacks.forEach { it(typedCtx, typedCtx.args.name) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + AddPrivateReducer.REDUCER_NAME -> { + if (onAddPrivateCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onAddPrivateCallbacks.forEach { it(typedCtx, typedCtx.args.name) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + AssertCallerIdentityIsModuleIdentityReducer.REDUCER_NAME -> { + if (onAssertCallerIdentityIsModuleIdentityCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onAssertCallerIdentityIsModuleIdentityCallbacks.forEach { it(typedCtx) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + DeletePlayerReducer.REDUCER_NAME -> { + if 
(onDeletePlayerCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onDeletePlayerCallbacks.forEach { it(typedCtx, typedCtx.args.id) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + DeletePlayersByNameReducer.REDUCER_NAME -> { + if (onDeletePlayersByNameCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onDeletePlayersByNameCallbacks.forEach { it(typedCtx, typedCtx.args.name) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + ListOverAgeReducer.REDUCER_NAME -> { + if (onListOverAgeCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onListOverAgeCallbacks.forEach { it(typedCtx, typedCtx.args.age) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + LogModuleIdentityReducer.REDUCER_NAME -> { + if (onLogModuleIdentityCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onLogModuleIdentityCallbacks.forEach { it(typedCtx) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + QueryPrivateReducer.REDUCER_NAME -> { + if (onQueryPrivateCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onQueryPrivateCallbacks.forEach { it(typedCtx) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + SayHelloReducer.REDUCER_NAME -> { + if (onSayHelloCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onSayHelloCallbacks.forEach { it(typedCtx) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + TestReducer.REDUCER_NAME -> { + if 
(onTestCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onTestCallbacks.forEach { it(typedCtx, typedCtx.args.arg, typedCtx.args.arg2, typedCtx.args.arg3, typedCtx.args.arg4) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + TestBtreeIndexArgsReducer.REDUCER_NAME -> { + if (onTestBtreeIndexArgsCallbacks.isNotEmpty()) { + @Suppress("UNCHECKED_CAST") + val typedCtx = ctx as EventContext.Reducer + onTestBtreeIndexArgsCallbacks.forEach { it(typedCtx) } + } else if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + else -> { + if (ctx.status is Status.Failed) { + onUnhandledReducerErrorCallbacks.forEach { it(ctx) } + } + } + } + } +} +''' +"RemoteTables.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ClientCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleTables + +/** Generated table accessors for all tables in this module. 
*/ +@OptIn(InternalSpacetimeApi::class) +class RemoteTables internal constructor( + private val conn: DbConnection, + private val clientCache: ClientCache, +) : ModuleTables { + val loggedOutPlayer: LoggedOutPlayerTableHandle by lazy { + @Suppress("UNCHECKED_CAST") + val cache = clientCache.getOrCreateTable(LoggedOutPlayerTableHandle.TABLE_NAME) { + LoggedOutPlayerTableHandle.createTableCache() + } + LoggedOutPlayerTableHandle(conn, cache) + } + + val myPlayer: MyPlayerTableHandle by lazy { + @Suppress("UNCHECKED_CAST") + val cache = clientCache.getOrCreateTable(MyPlayerTableHandle.TABLE_NAME) { + MyPlayerTableHandle.createTableCache() + } + MyPlayerTableHandle(conn, cache) + } + + val person: PersonTableHandle by lazy { + @Suppress("UNCHECKED_CAST") + val cache = clientCache.getOrCreateTable(PersonTableHandle.TABLE_NAME) { + PersonTableHandle.createTableCache() + } + PersonTableHandle(conn, cache) + } + + val player: PlayerTableHandle by lazy { + @Suppress("UNCHECKED_CAST") + val cache = clientCache.getOrCreateTable(PlayerTableHandle.TABLE_NAME) { + PlayerTableHandle.createTableCache() + } + PlayerTableHandle(conn, cache) + } + + val testD: TestDTableHandle by lazy { + @Suppress("UNCHECKED_CAST") + val cache = clientCache.getOrCreateTable(TestDTableHandle.TABLE_NAME) { + TestDTableHandle.createTableCache() + } + TestDTableHandle(conn, cache) + } + + val testF: TestFTableHandle by lazy { + @Suppress("UNCHECKED_CAST") + val cache = clientCache.getOrCreateTable(TestFTableHandle.TABLE_NAME) { + TestFTableHandle.createTableCache() + } + TestFTableHandle(conn, cache) + } + +} +''' +"ReturnValueProcedure.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. 
+ +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +/** Arguments for the `return_value` procedure. */ +data class ReturnValueArgs( + val foo: ULong +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + writer.writeU64(foo) + return writer.toByteArray() + } + + companion object { + /** Decodes [ReturnValueArgs] from BSATN. */ + fun decode(reader: BsatnReader): ReturnValueArgs { + val foo = reader.readU64() + return ReturnValueArgs(foo) + } + } +} + +object ReturnValueProcedure { + const val PROCEDURE_NAME = "return_value" + // Returns: Baz +} +''' +"SayHelloReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Constants for the `say_hello` reducer. */ +object SayHelloReducer { + const val REDUCER_NAME = "say_hello" +} +''' +"SleepOneSecondProcedure.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +object SleepOneSecondProcedure { + const val PROCEDURE_NAME = "sleep_one_second" + // Returns: Unit +} +''' +"TestBtreeIndexArgsReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. 
EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Constants for the `test_btree_index_args` reducer. */ +object TestBtreeIndexArgsReducer { + const val REDUCER_NAME = "test_btree_index_args" +} +''' +"TestDTableHandle.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTable +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult + +/** Client-side handle for the `test_d` table. 
*/ +@OptIn(InternalSpacetimeApi::class) +class TestDTableHandle internal constructor( + private val conn: DbConnection, + private val tableCache: TableCache, +) : RemotePersistentTable { + companion object { + const val TABLE_NAME = "test_d" + + const val FIELD_TEST_C = "test_c" + + fun createTableCache(): TableCache { + return TableCache.withContentKey { reader -> TestD.decode(reader) } + } + } + + override fun count(): Int = tableCache.count() + override fun all(): List = tableCache.all() + override fun iter(): Sequence = tableCache.iter() + + override fun onInsert(cb: (EventContext, TestD) -> Unit) { tableCache.onInsert(cb) } + override fun removeOnInsert(cb: (EventContext, TestD) -> Unit) { tableCache.removeOnInsert(cb) } + override fun onDelete(cb: (EventContext, TestD) -> Unit) { tableCache.onDelete(cb) } + override fun onBeforeDelete(cb: (EventContext, TestD) -> Unit) { tableCache.onBeforeDelete(cb) } + + override fun removeOnDelete(cb: (EventContext, TestD) -> Unit) { tableCache.removeOnDelete(cb) } + override fun removeOnBeforeDelete(cb: (EventContext, TestD) -> Unit) { tableCache.removeOnBeforeDelete(cb) } + +} + +@OptIn(InternalSpacetimeApi::class) +class TestDCols(tableName: String) { + val testC = Col(tableName, "test_c") +} + +class TestDIxCols +''' +"TestFTableHandle.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. 
+ +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTable +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult + +/** Client-side handle for the `test_f` table. */ +@OptIn(InternalSpacetimeApi::class) +class TestFTableHandle internal constructor( + private val conn: DbConnection, + private val tableCache: TableCache, +) : RemotePersistentTable { + companion object { + const val TABLE_NAME = "test_f" + + const val FIELD_FIELD = "field" + + fun createTableCache(): TableCache { + return TableCache.withContentKey { reader -> TestFoobar.decode(reader) } + } + } + + override fun count(): Int = tableCache.count() + override fun all(): List = tableCache.all() + override fun iter(): Sequence = tableCache.iter() + + override fun onInsert(cb: (EventContext, TestFoobar) -> Unit) { tableCache.onInsert(cb) } + override fun removeOnInsert(cb: (EventContext, TestFoobar) -> Unit) { tableCache.removeOnInsert(cb) } + override fun onDelete(cb: (EventContext, TestFoobar) -> Unit) { tableCache.onDelete(cb) } + override fun onBeforeDelete(cb: (EventContext, TestFoobar) -> Unit) { tableCache.onBeforeDelete(cb) } + + override fun removeOnDelete(cb: (EventContext, TestFoobar) -> Unit) { tableCache.removeOnDelete(cb) } + override fun removeOnBeforeDelete(cb: (EventContext, TestFoobar) -> Unit) { tableCache.removeOnBeforeDelete(cb) } + +} + +@OptIn(InternalSpacetimeApi::class) +class TestFCols(tableName: String) { + val field = Col(tableName, "field") +} + +class TestFIxCols +''' 
+"TestReducer.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Arguments for the `test` reducer. */ +data class TestArgs( + val arg: TestA, + val arg2: TestB, + val arg3: NamespaceTestC, + val arg4: NamespaceTestF +) { + /** Encodes these arguments to BSATN. */ + fun encode(): ByteArray { + val writer = BsatnWriter() + arg.encode(writer) + arg2.encode(writer) + arg3.encode(writer) + arg4.encode(writer) + return writer.toByteArray() + } + + companion object { + /** Decodes [TestArgs] from BSATN. */ + fun decode(reader: BsatnReader): TestArgs { + val arg = TestA.decode(reader) + val arg2 = TestB.decode(reader) + val arg3 = NamespaceTestC.decode(reader) + val arg4 = NamespaceTestF.decode(reader) + return TestArgs(arg, arg2, arg3, arg4) + } + } +} + +/** Constants for the `test` reducer. */ +object TestReducer { + const val REDUCER_NAME = "test" +} +''' +"Types.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ScheduleAt +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp + +/** Data type `Baz` from the module schema. 
*/ +data class Baz( + val field: String +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeString(field) + } + + companion object { + /** Decodes a [Baz] from BSATN. */ + fun decode(reader: BsatnReader): Baz { + val field = reader.readString() + return Baz(field) + } + } +} + +/** Sum type `Foobar` from the module schema. */ +sealed interface Foobar { + data class Baz(val value: module_bindings.Baz) : Foobar + data object Bar : Foobar + data class Har(val value: UInt) : Foobar + + fun encode(writer: BsatnWriter) { + when (this) { + is Baz -> { + writer.writeSumTag(0u) + value.encode(writer) + } + is Bar -> writer.writeSumTag(1u) + is Har -> { + writer.writeSumTag(2u) + writer.writeU32(value) + } + } + } + + companion object { + fun decode(reader: BsatnReader): Foobar { + return when (val tag = reader.readSumTag().toInt()) { + 0 -> Baz(module_bindings.Baz.decode(reader)) + 1 -> Bar + 2 -> Har(reader.readU32()) + else -> error("Unknown Foobar tag: $tag") + } + } + } +} + +/** Data type `HasSpecialStuff` from the module schema. */ +data class HasSpecialStuff( + val identity: Identity, + val connectionId: ConnectionId +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + identity.encode(writer) + connectionId.encode(writer) + } + + companion object { + /** Decodes a [HasSpecialStuff] from BSATN. */ + fun decode(reader: BsatnReader): HasSpecialStuff { + val identity = Identity.decode(reader) + val connectionId = ConnectionId.decode(reader) + return HasSpecialStuff(identity, connectionId) + } + } +} + +/** Data type `Person` from the module schema. */ +data class Person( + val id: UInt, + val name: String, + val age: UByte +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeU32(id) + writer.writeString(name) + writer.writeU8(age) + } + + companion object { + /** Decodes a [Person] from BSATN. 
*/ + fun decode(reader: BsatnReader): Person { + val id = reader.readU32() + val name = reader.readString() + val age = reader.readU8() + return Person(id, name, age) + } + } +} + +/** Data type `PkMultiIdentity` from the module schema. */ +data class PkMultiIdentity( + val id: UInt, + val other: UInt +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeU32(id) + writer.writeU32(other) + } + + companion object { + /** Decodes a [PkMultiIdentity] from BSATN. */ + fun decode(reader: BsatnReader): PkMultiIdentity { + val id = reader.readU32() + val other = reader.readU32() + return PkMultiIdentity(id, other) + } + } +} + +/** Data type `Player` from the module schema. */ +data class Player( + val identity: Identity, + val playerId: ULong, + val name: String +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + identity.encode(writer) + writer.writeU64(playerId) + writer.writeString(name) + } + + companion object { + /** Decodes a [Player] from BSATN. */ + fun decode(reader: BsatnReader): Player { + val identity = Identity.decode(reader) + val playerId = reader.readU64() + val name = reader.readString() + return Player(identity, playerId, name) + } + } +} + +/** Data type `Point` from the module schema. */ +data class Point( + val x: Long, + val y: Long +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeI64(x) + writer.writeI64(y) + } + + companion object { + /** Decodes a [Point] from BSATN. */ + fun decode(reader: BsatnReader): Point { + val x = reader.readI64() + val y = reader.readI64() + return Point(x, y) + } + } +} + +/** Data type `PrivateTable` from the module schema. */ +data class PrivateTable( + val name: String +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeString(name) + } + + companion object { + /** Decodes a [PrivateTable] from BSATN. 
*/ + fun decode(reader: BsatnReader): PrivateTable { + val name = reader.readString() + return PrivateTable(name) + } + } +} + +/** Data type `RemoveTable` from the module schema. */ +data class RemoveTable( + val id: UInt +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeU32(id) + } + + companion object { + /** Decodes a [RemoveTable] from BSATN. */ + fun decode(reader: BsatnReader): RemoveTable { + val id = reader.readU32() + return RemoveTable(id) + } + } +} + +/** Data type `RepeatingTestArg` from the module schema. */ +data class RepeatingTestArg( + val scheduledId: ULong, + val scheduledAt: ScheduleAt, + val prevTime: Timestamp +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeU64(scheduledId) + scheduledAt.encode(writer) + prevTime.encode(writer) + } + + companion object { + /** Decodes a [RepeatingTestArg] from BSATN. */ + fun decode(reader: BsatnReader): RepeatingTestArg { + val scheduledId = reader.readU64() + val scheduledAt = ScheduleAt.decode(reader) + val prevTime = Timestamp.decode(reader) + return RepeatingTestArg(scheduledId, scheduledAt, prevTime) + } + } +} + +/** Data type `TestA` from the module schema. */ +data class TestA( + val x: UInt, + val y: UInt, + val z: String +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeU32(x) + writer.writeU32(y) + writer.writeString(z) + } + + companion object { + /** Decodes a [TestA] from BSATN. */ + fun decode(reader: BsatnReader): TestA { + val x = reader.readU32() + val y = reader.readU32() + val z = reader.readString() + return TestA(x, y, z) + } + } +} + +/** Data type `TestB` from the module schema. */ +data class TestB( + val foo: String +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeString(foo) + } + + companion object { + /** Decodes a [TestB] from BSATN. 
*/ + fun decode(reader: BsatnReader): TestB { + val foo = reader.readString() + return TestB(foo) + } + } +} + +/** Data type `TestD` from the module schema. */ +data class TestD( + val testC: NamespaceTestC? +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + if (testC != null) { + writer.writeSumTag(0u) + testC.encode(writer) + } else { + writer.writeSumTag(1u) + } + } + + companion object { + /** Decodes a [TestD] from BSATN. */ + fun decode(reader: BsatnReader): TestD { + val testC = if (reader.readSumTag().toInt() == 0) NamespaceTestC.decode(reader) else null + return TestD(testC) + } + } +} + +/** Data type `TestE` from the module schema. */ +data class TestE( + val id: ULong, + val name: String +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + writer.writeU64(id) + writer.writeString(name) + } + + companion object { + /** Decodes a [TestE] from BSATN. */ + fun decode(reader: BsatnReader): TestE { + val id = reader.readU64() + val name = reader.readString() + return TestE(id, name) + } + } +} + +/** Data type `TestFoobar` from the module schema. */ +data class TestFoobar( + val field: Foobar +) { + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + field.encode(writer) + } + + companion object { + /** Decodes a [TestFoobar] from BSATN. */ + fun decode(reader: BsatnReader): TestFoobar { + val field = Foobar.decode(reader) + return TestFoobar(field) + } + } +} + +/** Enum type `NamespaceTestC` from the module schema. */ +enum class NamespaceTestC { + Foo, + Bar; + + fun encode(writer: BsatnWriter) { + writer.writeSumTag(ordinal.toUByte()) + } + + companion object { + fun decode(reader: BsatnReader): NamespaceTestC { + val tag = reader.readSumTag().toInt() + return entries.getOrElse(tag) { error("Unknown NamespaceTestC tag: $tag") } + } + } +} + +/** Sum type `NamespaceTestF` from the module schema. 
*/ +sealed interface NamespaceTestF { + data object Foo : NamespaceTestF + data object Bar : NamespaceTestF + data class Baz(val value: String) : NamespaceTestF + + fun encode(writer: BsatnWriter) { + when (this) { + is Foo -> writer.writeSumTag(0u) + is Bar -> writer.writeSumTag(1u) + is Baz -> { + writer.writeSumTag(2u) + writer.writeString(value) + } + } + } + + companion object { + fun decode(reader: BsatnReader): NamespaceTestF { + return when (val tag = reader.readSumTag().toInt()) { + 0 -> Foo + 1 -> Bar + 2 -> Baz(reader.readString()) + else -> error("Unknown NamespaceTestF tag: $tag") + } + } + } +} + +''' +"WithTxProcedure.kt" = ''' +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. + +@file:Suppress("UNUSED", "SpellCheckingInspection") + +package module_bindings + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +object WithTxProcedure { + const val PROCEDURE_NAME = "with_tx" + // Returns: Unit +} +''' diff --git a/crates/smoketests/src/lib.rs b/crates/smoketests/src/lib.rs index aa54aea04ef..d2c9eee0592 100644 --- a/crates/smoketests/src/lib.rs +++ b/crates/smoketests/src/lib.rs @@ -261,6 +261,65 @@ pub fn allow_dotnet() -> bool { } } +/// Returns the path to the Gradle wrapper (`gradlew`) in the Kotlin SDK directory. +/// +/// Returns `None` if the wrapper is not found. +pub fn gradlew_path() -> Option { + static GRADLEW_PATH: OnceLock> = OnceLock::new(); + GRADLEW_PATH + .get_or_init(|| { + let gradlew = workspace_root().join("sdks/kotlin/gradlew"); + if gradlew.exists() { + Some(gradlew) + } else { + None + } + }) + .clone() +} + +/// Returns true if a JDK is available on the system. 
+pub fn have_java() -> bool { + static HAVE_JAVA: OnceLock = OnceLock::new(); + *HAVE_JAVA.get_or_init(|| { + Command::new("javac") + .args(["--version"]) + .output() + .map(|output| output.status.success()) + .unwrap_or(false) + }) +} + +/// Returns true if tests are configured to allow Gradle (Kotlin SDK) tests. +pub fn allow_gradle() -> bool { + let Ok(s) = std::env::var("SMOKETESTS_GRADLE") else { + return true; + }; + match s.as_str() { + "" | "0" => false, + s => s.to_lowercase() != "false", + } +} + +#[macro_export] +macro_rules! require_gradle { + () => { + if !$crate::allow_gradle() { + #[allow(clippy::disallowed_macros)] + { + eprintln!("Skipping gradle test"); + } + return; + } + if $crate::gradlew_path().is_none() { + panic!("gradlew not found in sdks/kotlin/"); + } + if !$crate::have_java() { + panic!("JDK not found (javac not on PATH)"); + } + }; +} + /// Returns true if psql (PostgreSQL client) is available on the system. pub fn have_psql() -> bool { static HAVE_PSQL: OnceLock = OnceLock::new(); diff --git a/crates/smoketests/tests/smoketests/kotlin_sdk.rs b/crates/smoketests/tests/smoketests/kotlin_sdk.rs new file mode 100644 index 00000000000..8fab5dbeb1d --- /dev/null +++ b/crates/smoketests/tests/smoketests/kotlin_sdk.rs @@ -0,0 +1,150 @@ +#![allow(clippy::disallowed_macros)] +use spacetimedb_guard::{ensure_binaries_built, SpacetimeDbGuard}; +use spacetimedb_smoketests::{gradlew_path, patch_module_cargo_to_local_bindings, require_gradle, workspace_root}; +use std::fs; +use std::process::Command; +use std::sync::Mutex; + +/// Gradle builds sharing the same project directory cannot run in parallel. +/// This mutex serializes all Kotlin smoketests that invoke gradlew on sdks/kotlin/. +static GRADLE_LOCK: Mutex<()> = Mutex::new(()); + +/// Run the Kotlin SDK unit tests (BSATN codec, type round-trips, query builder, etc.). +/// Does not require a running SpacetimeDB server. +/// Skips if gradle is not available or disabled via SMOKETESTS_GRADLE=0. 
+#[test] +fn test_kotlin_sdk_unit_tests() { + require_gradle!(); + let _lock = GRADLE_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + + let workspace = workspace_root(); + let cli_path = ensure_binaries_built(); + let kotlin_sdk_path = workspace.join("sdks/kotlin"); + let gradlew = gradlew_path().expect("gradlew not found"); + + // The spacetimedb Gradle plugin auto-generates bindings during compilation. + // Pass the CLI path via SPACETIMEDB_CLI so the plugin uses the freshly-built binary. + let output = Command::new(&gradlew) + .args([ + ":spacetimedb-sdk:jvmTest", + ":codegen-tests:test", + "--no-daemon", + "--no-configuration-cache", + ]) + .env("SPACETIMEDB_CLI", &cli_path) + .current_dir(&kotlin_sdk_path) + .output() + .expect("Failed to run gradlew :spacetimedb-sdk:allTests :codegen-tests:test"); + + if !output.status.success() { + panic!( + "Kotlin SDK unit tests failed:\nstdout: {}\nstderr: {}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + eprintln!("Kotlin SDK unit tests passed"); +} + +/// Run Kotlin SDK integration tests against a live SpacetimeDB server. +/// Spawns a local server, builds + publishes the integration test module, +/// then runs the Gradle integration tests with SPACETIMEDB_HOST set. +/// Skips if gradle is not available or disabled via SMOKETESTS_GRADLE=0. +#[test] +fn test_kotlin_integration() { + require_gradle!(); + let _lock = GRADLE_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + + let workspace = workspace_root(); + let cli_path = ensure_binaries_built(); + let kotlin_sdk_path = workspace.join("sdks/kotlin"); + let module_path = kotlin_sdk_path.join("integration-tests/spacetimedb"); + + // Isolate CLI config so we don't reuse stale tokens from the user's home config. + // This mirrors what Smoketest.spacetime_cmd() does via --config-path. 
+ let config_dir = tempfile::tempdir().expect("Failed to create temp config dir"); + let config_path = config_dir.path().join("config.toml"); + + // Helper: build a Command with --config-path already set. + let cli = |extra_args: &[&str]| -> std::process::Output { + Command::new(&cli_path) + .arg("--config-path") + .arg(&config_path) + .args(extra_args) + .output() + .expect("Failed to run spacetime CLI command") + }; + + // Step 1: Spawn a local SpacetimeDB server + let guard = SpacetimeDbGuard::spawn_in_temp_data_dir_with_pg_port(None); + let server_url = &guard.host_url; + eprintln!("[KOTLIN-INTEGRATION] Server running at {server_url}"); + + // Step 2: Patch the module to use local bindings and build it + patch_module_cargo_to_local_bindings(&module_path).expect("Failed to patch module Cargo.toml"); + + let toolchain_src = workspace.join("rust-toolchain.toml"); + if toolchain_src.exists() { + fs::copy(&toolchain_src, module_path.join("rust-toolchain.toml")).expect("Failed to copy rust-toolchain.toml"); + } + + let output = cli(&["build", "--module-path", module_path.to_str().unwrap()]); + assert!( + output.status.success(), + "spacetime build failed:\nstdout: {}\nstderr: {}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + // Step 4: Publish the module + let db_name = "kotlin-integration-test"; + let output = cli(&[ + "publish", + "--server", + server_url, + "--module-path", + module_path.to_str().unwrap(), + "--no-config", + "-y", + db_name, + ]); + assert!( + output.status.success(), + "spacetime publish failed:\nstdout: {}\nstderr: {}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + eprintln!("[KOTLIN-INTEGRATION] Module published as '{db_name}'"); + + // Step 5: Run Gradle integration tests + let gradlew = gradlew_path().expect("gradlew not found"); + let ws_url = server_url.replace("http://", "ws://").replace("https://", "wss://"); + + let output = 
Command::new(&gradlew) + .args([ + ":integration-tests:clean", + ":integration-tests:test", + "-PintegrationTests", + "--no-daemon", + "--no-configuration-cache", + "--stacktrace", + ]) + .env("SPACETIMEDB_CLI", &cli_path) + .env("SPACETIMEDB_HOST", &ws_url) + .env("SPACETIMEDB_DB_NAME", db_name) + .current_dir(&kotlin_sdk_path) + .output() + .expect("Failed to run gradle integration tests"); + + if !output.status.success() { + panic!( + "Kotlin integration tests failed:\nstdout: {}\nstderr: {}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + eprintln!("[KOTLIN-INTEGRATION] All integration tests passed"); + drop(guard); +} diff --git a/crates/smoketests/tests/smoketests/mod.rs b/crates/smoketests/tests/smoketests/mod.rs index f5053652dd3..96acf893e6c 100644 --- a/crates/smoketests/tests/smoketests/mod.rs +++ b/crates/smoketests/tests/smoketests/mod.rs @@ -19,6 +19,7 @@ mod domains; mod fail_initial_publish; mod filtering; mod http_egress; +mod kotlin_sdk; mod logs_level_filter; mod module_nested_op; mod modules; diff --git a/crates/smoketests/tests/smoketests/templates.rs b/crates/smoketests/tests/smoketests/templates.rs index 8de55871f58..65f24391a91 100644 --- a/crates/smoketests/tests/smoketests/templates.rs +++ b/crates/smoketests/tests/smoketests/templates.rs @@ -547,6 +547,84 @@ fn setup_rust_client_sdk(project_path: &Path) -> Result<()> { update_cargo_toml_dependency(&project_path.join("Cargo.toml"), "spacetimedb-sdk", &sdk_path) } +/// Wires a Kotlin template project to the local Kotlin SDK via `includeBuild` +/// and sets the CLI path in the `spacetimedb` plugin configuration. 
+fn setup_kotlin_client_sdk(project_path: &Path) -> Result<()> { + let workspace = workspace_root(); + let kotlin_sdk_path = workspace.join("sdks/kotlin"); + let cli_path = spacetimedb_guard::ensure_binaries_built(); + + // Uncomment includeBuild lines in settings.gradle.kts + let settings_path = project_path.join("settings.gradle.kts"); + let settings = fs::read_to_string(&settings_path).with_context(|| format!("Failed to read {:?}", settings_path))?; + let sdk_path_str = kotlin_sdk_path.display().to_string().replace('\\', "/"); + let patched = settings + .replace( + "// includeBuild(\"/spacetimedb-gradle-plugin\")", + &format!("includeBuild(\"{}/spacetimedb-gradle-plugin\")", sdk_path_str), + ) + .replace( + "// includeBuild(\"\")", + &format!("includeBuild(\"{}\")", sdk_path_str), + ); + fs::write(&settings_path, patched).with_context(|| format!("Failed to write {:?}", settings_path))?; + + // Find the build.gradle.kts that applies the spacetimedb plugin (not `apply false`) + // and append a spacetimedb {} block with the CLI path. 
+ let cli_path_str = cli_path.display().to_string().replace('\\', "/"); + let plugin_build_file = find_spacetimedb_plugin_build_file(project_path).with_context(|| { + format!( + "No build.gradle.kts applying the spacetimedb plugin found in {:?}", + project_path + ) + })?; + let content = + fs::read_to_string(&plugin_build_file).with_context(|| format!("Failed to read {:?}", plugin_build_file))?; + let patched = format!( + "{}\nspacetimedb {{\n cli.set(file(\"{}\"))\n}}\n", + content, cli_path_str + ); + fs::write(&plugin_build_file, patched).with_context(|| format!("Failed to write {:?}", plugin_build_file))?; + + // Copy Gradle wrapper from the SDK + let gradlew_src = kotlin_sdk_path.join("gradlew"); + if gradlew_src.exists() { + fs::copy(&gradlew_src, project_path.join("gradlew")).context("Failed to copy gradlew")?; + let wrapper_src = kotlin_sdk_path.join("gradle/wrapper"); + let wrapper_dst = project_path.join("gradle/wrapper"); + fs::create_dir_all(&wrapper_dst).context("Failed to create gradle/wrapper")?; + for entry in fs::read_dir(&wrapper_src) + .context("Failed to read gradle/wrapper")? + .flatten() + { + fs::copy(entry.path(), wrapper_dst.join(entry.file_name())) + .context("Failed to copy gradle wrapper file")?; + } + } + + Ok(()) +} + +/// Recursively searches for a `build.gradle.kts` that applies the spacetimedb +/// plugin (not with `apply false`). +fn find_spacetimedb_plugin_build_file(dir: &Path) -> Result { + for entry in fs::read_dir(dir).with_context(|| format!("Failed to read {:?}", dir))? 
{ + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + if let Ok(found) = find_spacetimedb_plugin_build_file(&path) { + return Ok(found); + } + } else if path.file_name().is_some_and(|n| n == "build.gradle.kts") { + let content = fs::read_to_string(&path)?; + if content.contains("alias(libs.plugins.spacetimedb)") && !content.contains("spacetimedb) apply false") { + return Ok(path); + } + } + } + bail!("spacetimedb plugin not found in {:?}", dir) +} + /// Creates a local `nuget.config`, packs all required SpacetimeDB C# packages /// from source, and registers them as local NuGet sources. fn setup_csharp_nuget(project_path: &Path) -> Result { @@ -670,6 +748,28 @@ fn test_rust_template(test: &Smoketest, template: &Template, project_path: &Path String::from_utf8_lossy(&output.stderr) ); } + } else if template.client_lang.as_deref() == Some("kotlin") { + setup_kotlin_client_sdk(project_path)?; + let gradlew = spacetimedb_smoketests::gradlew_path() + .context("gradlew not found — cannot build Kotlin template client")?; + let output = Command::new(&gradlew) + .args([ + "compileKotlin", + "--no-daemon", + "--no-configuration-cache", + "--stacktrace", + ]) + .current_dir(project_path) + .output() + .context("Failed to run gradlew compileKotlin")?; + if !output.status.success() { + bail!( + "gradle compileKotlin for {} client failed:\nstdout: {}\nstderr: {}", + template.id, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } } Ok(()) } diff --git a/docs/docs/00100-intro/00100-getting-started/00300-language-support.md b/docs/docs/00100-intro/00100-getting-started/00300-language-support.md index 6f5ab3a9f5e..31206107358 100644 --- a/docs/docs/00100-intro/00100-getting-started/00300-language-support.md +++ b/docs/docs/00100-intro/00100-getting-started/00300-language-support.md @@ -19,6 +19,7 @@ SpacetimeDB modules define your database schema and server-side business logic. 
- **[Rust](../../00200-core-concepts/00600-clients/00500-rust-reference.md)** - [(Quickstart)](../00200-quickstarts/00500-rust.md) - **[C#](../../00200-core-concepts/00600-clients/00600-csharp-reference.md)** - [(Quickstart)](../00200-quickstarts/00600-c-sharp.md) - **[TypeScript](../../00200-core-concepts/00600-clients/00700-typescript-reference.md)** - [(Quickstart)](../00200-quickstarts/00400-typescript.md) +- **[Kotlin](../../00200-core-concepts/00600-clients/00900-kotlin-reference.md)** - Kotlin Multiplatform [(Quickstart)](../00200-quickstarts/00800-kotlin.md) - **[Unreal Engine](../../00200-core-concepts/00600-clients/00800-unreal-reference.md)** - C++ and Blueprint support [(Tutorial)](../00300-tutorials/00400-unreal-tutorial/00200-part-1.md) ### Unity diff --git a/docs/docs/00100-intro/00100-getting-started/00500-faq.md b/docs/docs/00100-intro/00100-getting-started/00500-faq.md index 2552dd1b630..e288359895f 100644 --- a/docs/docs/00100-intro/00100-getting-started/00500-faq.md +++ b/docs/docs/00100-intro/00100-getting-started/00500-faq.md @@ -137,6 +137,7 @@ Client SDKs are available for: - **Rust** - **C#** (including Unity) - **TypeScript** (including React, Vue, Svelte, Angular, and more) +- **Kotlin** (Kotlin Multiplatform) - **C++** (Unreal Engine) SpacetimeDB 2.0 also includes a **type-safe query builder** for client-side subscriptions, so you do not have to write raw SQL strings if you prefer not to. 
diff --git a/docs/docs/00100-intro/00200-quickstarts/00800-kotlin.md b/docs/docs/00100-intro/00200-quickstarts/00800-kotlin.md new file mode 100644 index 00000000000..0779aacfa10 --- /dev/null +++ b/docs/docs/00100-intro/00200-quickstarts/00800-kotlin.md @@ -0,0 +1,193 @@ +--- +title: Kotlin Quickstart +sidebar_label: Kotlin +slug: /quickstarts/kotlin +hide_table_of_contents: true +--- + +import { InstallCardLink } from "@site/src/components/InstallCardLink"; +import { StepByStep, Step, StepText, StepCode } from "@site/src/components/Steps"; + + +Get a SpacetimeDB Kotlin app running in under 5 minutes. + +This quickstart uses the `basic-kt` template, a JVM-only console app. For a Kotlin Multiplatform project targeting Android and Desktop, use the `compose-kt` template instead. + +## Prerequisites + +- JDK 21+ installed +- [SpacetimeDB CLI](https://spacetimedb.com/install) installed + + + +--- + + + + + Run the `spacetime dev` command to create a new project with a Kotlin client and Rust server module. + + This will start the local SpacetimeDB server, compile and publish your module, and generate Kotlin client bindings. + + +```bash +spacetime dev --template basic-kt +``` + + + + + + Your project contains a Rust server module and a Kotlin client. The Gradle plugin auto-generates typed bindings into `build/generated/` on compile. + + +``` +my-spacetime-app/ +├── spacetimedb/ # Your SpacetimeDB module (Rust) +│ ├── Cargo.toml +│ └── src/lib.rs # Server-side logic +├── src/main/kotlin/ +│ └── Main.kt # Client application +├── build/generated/spacetimedb/ +│ └── bindings/ # Auto-generated types +├── build.gradle.kts +└── settings.gradle.kts +``` + + + + + + Open `spacetimedb/src/lib.rs` to see the module code. The template includes a `Person` table, three lifecycle reducers (`init`, `client_connected`, `client_disconnected`), and two application reducers: `add` to insert a person, and `say_hello` to greet everyone. + + Tables store your data. 
Reducers are functions that modify data — they're the only way to write to the database. + + +```rust +use spacetimedb::{ReducerContext, Table}; + +#[spacetimedb::table(accessor = person, public)] +pub struct Person { + #[primary_key] + #[auto_inc] + id: u64, + name: String, +} + +#[spacetimedb::reducer(init)] +pub fn init(_ctx: &ReducerContext) { + // Called when the module is initially published +} + +#[spacetimedb::reducer(client_connected)] +pub fn identity_connected(_ctx: &ReducerContext) { + // Called everytime a new client connects +} + +#[spacetimedb::reducer(client_disconnected)] +pub fn identity_disconnected(_ctx: &ReducerContext) { + // Called everytime a client disconnects +} + +#[spacetimedb::reducer] +pub fn add(ctx: &ReducerContext, name: String) { + ctx.db.person().insert(Person { id: 0, name }); +} + +#[spacetimedb::reducer] +pub fn say_hello(ctx: &ReducerContext) { + for person in ctx.db.person().iter() { + log::info!("Hello, {}!", person.name); + } + log::info!("Hello, World!"); +} +``` + + + + + + Open `src/main/kotlin/Main.kt`. The client connects to SpacetimeDB, subscribes to tables, registers callbacks, and calls reducers — all with generated type-safe bindings. 
+ + +```kotlin +suspend fun main() { + val host = System.getenv("SPACETIMEDB_HOST") ?: "ws://localhost:3000" + val httpClient = HttpClient(OkHttp) { install(WebSockets) } + + DbConnection.Builder() + .withHttpClient(httpClient) + .withUri(host) + .withDatabaseName(module_bindings.SpacetimeConfig.DATABASE_NAME) + .withModuleBindings() + .onConnect { conn, identity, _ -> + println("Connected to SpacetimeDB!") + println("Identity: ${identity.toHexString().take(16)}...") + + conn.db.person.onInsert { _, person -> + println("New person: ${person.name}") + } + + conn.reducers.onAdd { ctx, name -> + println("[onAdd] Added person: $name (status=${ctx.status})") + } + + conn.subscriptionBuilder() + .onError { _, error -> println("Subscription error: $error") } + .subscribeToAllTables() + + conn.reducers.add("Alice") { ctx -> + println("[one-shot] Add completed: status=${ctx.status}") + conn.reducers.sayHello() + } + } + .onDisconnect { _, error -> + if (error != null) { + println("Disconnected with error: $error") + } else { + println("Disconnected") + } + } + .onConnectError { _, error -> + println("Connection error: $error") + } + .build() + .use { delay(5.seconds) } +} +``` + + + + + + Open a new terminal and navigate to your project directory. Then use the SpacetimeDB CLI to call reducers and query your data directly. + + +```bash +cd my-spacetime-app + +# Call the add reducer to insert a person +spacetime call add Alice + +# Query the person table +spacetime sql "SELECT * FROM person" + id | name +----+--------- + 1 | "Alice" + +# Call say_hello to greet everyone +spacetime call say_hello + +# View the module logs +spacetime logs +2025-01-13T12:00:00.000000Z INFO: Hello, Alice! +2025-01-13T12:00:00.000000Z INFO: Hello, World! 
+``` + + + + +## Next steps + +- Read the [Kotlin SDK Reference](../../00200-core-concepts/00600-clients/00900-kotlin-reference.md) for detailed API docs +- Try the `compose-kt` template (`spacetime init --template compose-kt`) for a full KMP chat client with Compose Multiplatform diff --git a/docs/docs/00200-core-concepts/00100-databases/00200-spacetime-dev.md b/docs/docs/00200-core-concepts/00100-databases/00200-spacetime-dev.md index f8384e5bc6e..53c2ae4e7f7 100644 --- a/docs/docs/00200-core-concepts/00100-databases/00200-spacetime-dev.md +++ b/docs/docs/00200-core-concepts/00100-databases/00200-spacetime-dev.md @@ -68,8 +68,10 @@ Choose from several built-in templates: - `basic-ts` - Basic TypeScript client and server stubs - `basic-cs` - Basic C# client and server stubs - `basic-rs` - Basic Rust client and server stubs +- `basic-kt` - Basic Kotlin client and Rust server stubs - `basic-cpp` - Basic C++ server stubs - `react-ts` - React web app with TypeScript server +- `compose-kt` - Compose Multiplatform chat app with Rust server - `chat-console-rs` - Complete Rust chat implementation - `chat-console-cs` - Complete C# chat implementation - `chat-react-ts` - Complete TypeScript chat implementation diff --git a/docs/docs/00200-core-concepts/00600-clients.md b/docs/docs/00200-core-concepts/00600-clients.md index e149b3b63a0..e2157e6d765 100644 --- a/docs/docs/00200-core-concepts/00600-clients.md +++ b/docs/docs/00200-core-concepts/00600-clients.md @@ -12,6 +12,7 @@ SpacetimeDB provides client SDKs for multiple languages: - [Rust](./00600-clients/00500-rust-reference.md) - [(Quickstart)](../00100-intro/00200-quickstarts/00500-rust.md) - [C#](./00600-clients/00600-csharp-reference.md) - [(Quickstart)](../00100-intro/00200-quickstarts/00600-c-sharp.md) - [TypeScript](./00600-clients/00700-typescript-reference.md) - [(Quickstart)](../00100-intro/00200-quickstarts/00400-typescript.md) +- [Kotlin](./00600-clients/00900-kotlin-reference.md) - 
[(Quickstart)](../00100-intro/00200-quickstarts/00800-kotlin.md) - [Unreal](./00600-clients/00800-unreal-reference.md) - [(Tutorial)](../00100-intro/00300-tutorials/00400-unreal-tutorial/index.md) ## Getting Started diff --git a/docs/docs/00200-core-concepts/00600-clients/00900-kotlin-reference.md b/docs/docs/00200-core-concepts/00600-clients/00900-kotlin-reference.md new file mode 100644 index 00000000000..85738bcbd8a --- /dev/null +++ b/docs/docs/00200-core-concepts/00600-clients/00900-kotlin-reference.md @@ -0,0 +1,352 @@ +--- +title: Kotlin Reference +slug: /clients/kotlin +--- + +The SpacetimeDB client SDK for Kotlin Multiplatform, targeting Android, JVM (Desktop), and iOS/Native. + +Two templates are available: +- `basic-kt` — JVM-only console app (simplest starting point) +- `compose-kt` — Compose Multiplatform app targeting Android and Desktop + +Before diving into the reference, you may want to review: + +- [Generating Client Bindings](./00200-codegen.md) - How to generate Kotlin bindings from your module +- [Connecting to SpacetimeDB](./00300-connection.md) - Establishing and managing connections +- [SDK API Reference](./00400-sdk-api.md) - Core concepts that apply across all SDKs + +| Name | Description | +| ----------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| [Project setup](#project-setup) | Configure your Kotlin project to use the SpacetimeDB Kotlin SDK. | +| [Generate module bindings](#generate-module-bindings) | Generated types and how the Gradle plugin automates codegen. | +| [`DbConnection` type](#type-dbconnection) | A connection to a remote database. | +| [`EventContext` type](#type-eventcontext) | Context available in row and reducer callbacks. | +| [Access the client cache](#access-the-client-cache) | Query subscribed rows and register row callbacks. 
| +| [Observe and invoke reducers](#observe-and-invoke-reducers) | Call reducers and register callbacks for reducer events. | +| [Subscribe to queries](#subscribe-to-queries) | Subscribe to table data using the type-safe query builder. | +| [Identify a client](#identify-a-client) | Types for identifying users and client connections. | +| [Type mappings](#type-mappings) | How SpacetimeDB types map to Kotlin types. | + +## Project setup + +### Using `spacetime dev` (recommended) + +The fastest way to get started: + +```bash +# JVM-only console app +spacetime dev --template basic-kt + +# Compose Multiplatform (Android + Desktop) +spacetime dev --template compose-kt +``` + +Both templates come with the Gradle plugin pre-configured. + +### Manual setup + +Add the SpacetimeDB Gradle plugin to your `build.gradle.kts`: + +```kotlin +plugins { + id("com.clockworklabs.spacetimedb") +} + +spacetimedb { + modulePath.set(file("spacetimedb")) +} + +dependencies { + implementation("com.clockworklabs:spacetimedb-sdk") +} +``` + +In `settings.gradle.kts`, add the plugin repository: + +```kotlin +pluginManagement { + repositories { + gradlePluginPortal() + mavenCentral() + } +} +``` + +The SDK requires JDK 21+ and uses [Ktor](https://ktor.io/) for WebSocket transport. Add a Ktor engine dependency for your platform: + +```kotlin +// JVM / Android +implementation("io.ktor:ktor-client-okhttp:3.4.1") + +// iOS / Native +implementation("io.ktor:ktor-client-darwin:3.4.1") +``` + +```kotlin +// All platforms need the WebSockets plugin +implementation("io.ktor:ktor-client-websockets:3.4.1") +``` + +## Generate module bindings + +The SpacetimeDB Gradle plugin automatically generates Kotlin bindings when you compile. Bindings are generated into `build/generated/spacetimedb/bindings/` and wired into the Kotlin compilation automatically. 
+ +Generated files include: + +| File | Description | +| ---- | ----------- | +| `Types.kt` | All user-defined types (`data class`, `sealed interface`, `enum class`) | +| `{Table}TableHandle.kt` | Table handle with field name constants, cache accessors, and callbacks | +| `{Reducer}Reducer.kt` | Reducer args `data class` and name constant | +| `RemoteTables.kt` | Aggregates all table accessors | +| `RemoteReducers.kt` | Reducer call stubs with one-shot callbacks | +| `RemoteProcedures.kt` | Procedure call methods and callback registration | +| `Module.kt` | Module descriptor, `QueryBuilder`, and `subscribeToAllTables` extension | + +You can also generate bindings manually: + +```bash +spacetime generate --lang kotlin --out-dir src/main/kotlin/module_bindings --module-path spacetimedb +``` + +## Type `DbConnection` + +A `DbConnection` represents a live WebSocket connection to a SpacetimeDB database. Create one using the builder: + +```kotlin +val httpClient = HttpClient(OkHttp) { install(WebSockets) } + +val conn = DbConnection.Builder() + .withHttpClient(httpClient) + .withUri("ws://localhost:3000") + .withDatabaseName("my-database") + .withModuleBindings() + .onConnect { conn, identity, token -> + // Connected — register callbacks, subscribe, call reducers + } + .onDisconnect { conn, error -> + // Disconnected — error is null for clean disconnects + } + .onConnectError { conn, error -> + // Connection failed + } + .build() +``` + +### Builder methods + +| Method | Description | +| ------ | ----------- | +| `withHttpClient(client)` | Ktor `HttpClient` with WebSockets installed | +| `withUri(uri)` | WebSocket URL (e.g. 
`ws://localhost:3000`) | +| `withDatabaseName(name)` | Database name or address | +| `withToken(token)` | Auth token (nullable, for reconnecting with saved identity) | +| `withModuleBindings()` | Generated extension that registers the module descriptor | +| `onConnect(cb)` | Called after successful connection with `(DbConnectionView, Identity, String)` | +| `onDisconnect(cb)` | Called on disconnect with `(DbConnectionView, Throwable?)` | +| `onConnectError(cb)` | Called on connection failure with `(DbConnectionView, Throwable)` | +| `build()` | Suspending — connects and returns the `DbConnection` | + +### Using `use` for automatic cleanup + +The SDK provides a `use` extension that keeps the connection alive and disconnects when the block completes: + +```kotlin +conn.use { + delay(Duration.INFINITE) // Keep alive until cancelled +} +``` + +### Accessing generated modules + +Inside callbacks, the connection exposes generated accessors: + +```kotlin +conn.db.person // Table handle for the "person" table +conn.reducers.add() // Call the "add" reducer +``` + +These are generated extension properties — `db`, `reducers`, and `procedures`. + +## Type `EventContext` + +Callbacks receive an `EventContext` that provides access to the database and metadata about the event: + +```kotlin +conn.db.person.onInsert { ctx, person -> + // ctx.db, ctx.reducers, ctx.procedures are available + // ctx is an EventContext +} +``` + +Reducer callbacks receive an `EventContext.Reducer` with additional fields: + +```kotlin +conn.reducers.onAdd { ctx, name -> + ctx.status // Status (Committed, Failed) + ctx.callerIdentity // Identity of the caller +} +``` + +## Access the client cache + +Each table handle provides methods to read cached rows and register callbacks. 
+ +### Read rows + +```kotlin +conn.db.person.count() // Number of cached rows +conn.db.person.all() // List of all cached rows +conn.db.person.iter() // Sequence for lazy iteration +``` + +### Row callbacks + +```kotlin +// Called when a row is inserted +conn.db.person.onInsert { ctx, person -> + println("Inserted: ${person.name}") +} + +// Called when a row is deleted +conn.db.person.onDelete { ctx, person -> + println("Deleted: ${person.name}") +} + +// Called when a row is updated (tables with primary keys only) +conn.db.person.onUpdate { ctx, oldPerson, newPerson -> + println("Updated: ${oldPerson.name} -> ${newPerson.name}") +} + +// Called before a row is deleted (for pre-delete logic) +conn.db.person.onBeforeDelete { ctx, person -> + println("About to delete: ${person.name}") +} +``` + +Remove callbacks by passing the same function reference: + +```kotlin +val cb: (EventContext, Person) -> Unit = { _, p -> println(p.name) } +conn.db.person.onInsert(cb) +conn.db.person.removeOnInsert(cb) +``` + +### Index lookups + +For tables with unique indexes: + +```kotlin +conn.db.person.id.find(42u) // Person? — lookup by unique index +``` + +For tables with BTree indexes: + +```kotlin +conn.db.person.nameIdx.filter("Alice") // Set<Person> — filter by index +``` + +## Observe and invoke reducers + +### Call a reducer + +```kotlin +conn.reducers.add("Alice") +``` + +### Call with a one-shot callback + +```kotlin +conn.reducers.add("Alice") { ctx -> + println("Add completed: status=${ctx.status}") +} +``` + +The one-shot callback fires only for this specific call. 
+ +### Observe all calls to a reducer + +```kotlin +conn.reducers.onAdd { ctx, name -> + println("Someone called add($name), status=${ctx.status}") +} +``` + +## Subscribe to queries + +### Subscribe to all tables + +```kotlin +conn.subscriptionBuilder() + .onError { _, error -> println("Subscription error: $error") } + .subscribeToAllTables() +``` + +### Type-safe query builder + +Use the generated `QueryBuilder` for type-safe subscriptions: + +```kotlin +conn.subscriptionBuilder() + .addQuery { qb -> qb.person().where { cols -> cols.name.eq("Alice") } } + .onApplied { println("Subscription applied") } + .subscribe() +``` + +The query builder supports: + +| Method | Description | +| ------ | ----------- | +| `where { cols -> expr }` | Filter rows by column predicates | +| `leftSemijoin(other) { l, r -> expr }` | Keep left rows that match right | +| `rightSemijoin(other) { l, r -> expr }` | Keep right rows that match left | + +Column predicates: `eq`, `neq`, `lt`, `lte`, `gt`, `gte`, combined with `and` / `or`. + +## Identify a client + +### `Identity` + +A unique identifier for a user, consistent across connections. Represented as a 32-byte value. + +```kotlin +val hex = identity.toHexString() +``` + +### `ConnectionId` + +Identifies a specific connection (a user can have multiple). 
+ +## Type mappings + +| SpacetimeDB Type | Kotlin Type | +| ---------------- | ----------- | +| `bool` | `Boolean` | +| `u8` | `UByte` | +| `u16` | `UShort` | +| `u32` | `UInt` | +| `u64` | `ULong` | +| `u128` | `UInt128` | +| `u256` | `UInt256` | +| `i8` | `Byte` | +| `i16` | `Short` | +| `i32` | `Int` | +| `i64` | `Long` | +| `i128` | `Int128` | +| `i256` | `Int256` | +| `f32` | `Float` | +| `f64` | `Double` | +| `String` | `String` | +| `Vec` / `bytes` | `ByteArray` | +| `Vec` / `Array` | `List` | +| `Option` | `T?` | +| `Identity` | `Identity` | +| `ConnectionId` | `ConnectionId` | +| `Timestamp` | `Timestamp` | +| `TimeDuration` | `TimeDuration` | +| `ScheduleAt` | `ScheduleAt` | +| `Uuid` | `SpacetimeUuid` | +| `Result` | `SpacetimeResult` | +| Product types | `data class` | +| Sum types (all unit) | `enum class` | +| Sum types (mixed) | `sealed interface` | diff --git a/docs/src/components/QuickstartLinks.tsx b/docs/src/components/QuickstartLinks.tsx index befd2d48a36..3656518aeb2 100644 --- a/docs/src/components/QuickstartLinks.tsx +++ b/docs/src/components/QuickstartLinks.tsx @@ -17,6 +17,7 @@ import NodeJSLogo from '@site/static/images/logos/nodejs-logo.svg'; import TypeScriptLogo from '@site/static/images/logos/typescript-logo.svg'; import RustLogo from '@site/static/images/logos/rust-logo.svg'; import CSharpLogo from '@site/static/images/logos/csharp-logo.svg'; +import KotlinLogo from '@site/static/images/logos/kotlin-logo.svg'; import CppLogo from '@site/static/images/logos/cpp-logo.svg'; const ALL_ITEMS: Item[] = [ @@ -110,6 +111,12 @@ const ALL_ITEMS: Item[] = [ docId: 'intro/quickstarts/c-sharp', label: 'C#', }, + { + icon: , + href: 'quickstarts/kotlin', + docId: 'intro/quickstarts/kotlin', + label: 'Kotlin', + }, { icon: , href: 'quickstarts/c-plus-plus', diff --git a/docs/static/images/logos/kotlin-logo.svg b/docs/static/images/logos/kotlin-logo.svg new file mode 100644 index 00000000000..b30f7a27213 --- /dev/null +++ 
b/docs/static/images/logos/kotlin-logo.svg @@ -0,0 +1,8 @@ + + + diff --git a/sdks/kotlin/.gitignore b/sdks/kotlin/.gitignore new file mode 100644 index 00000000000..34831c1718b --- /dev/null +++ b/sdks/kotlin/.gitignore @@ -0,0 +1,44 @@ +*.iml +.kotlin/ +.gradle/ +**/build/ +xcuserdata/ +!src/**/build/ +local.properties +.idea/ +.DS_Store +captures +.externalNativeBuild +.cxx +*.xcodeproj/* +!*.xcodeproj/project.pbxproj +!*.xcodeproj/xcshareddata/ +!*.xcodeproj/project.xcworkspace/ +!*.xcworkspace/contents.xcworkspacedata +**/xcshareddata/WorkspaceSettings.xcsettings + +# Logs +*.log + +# Database files +*.db +*.db-shm +*.db-wal + +# Server data directory +/data/ +server/data/ + +# Environment files +.env +.env.local + +# OS specific +Thumbs.db +.Trashes +._* + +# IDE specific +*.swp +*~ +.vscode/ diff --git a/sdks/kotlin/README.md b/sdks/kotlin/README.md new file mode 100644 index 00000000000..2e87f2487a0 --- /dev/null +++ b/sdks/kotlin/README.md @@ -0,0 +1,263 @@ +# SpacetimeDB Kotlin SDK + +Kotlin Multiplatform client SDK for [SpacetimeDB](https://spacetimedb.com). Connects to a SpacetimeDB module over WebSocket, synchronizes table state into an in-memory client cache, and provides typed access to tables, reducers, and procedures via generated bindings. + +## Supported Platforms + +| Platform | Minimum Version | +|----------|----------------| +| JVM | 21 | +| Android | API 26 | +| iOS | arm64 / x64 / simulator-arm64 | + +The SDK uses [Ktor](https://ktor.io/) for WebSocket transport. You must provide an `HttpClient` with a platform-appropriate engine (e.g. OkHttp for JVM/Android, Darwin for iOS) and the WebSockets plugin installed. 
+ +## Installation + +### Gradle Plugin (recommended) + +Apply the plugin to your module's `build.gradle.kts`: + +```kotlin +plugins { + id("com.clockworklabs.spacetimedb") +} + +spacetimedb { + // Path to spacetimedb-cli binary (defaults to "spacetimedb-cli" on PATH) + cli.set(file("/path/to/spacetimedb-cli")) + // Path to your SpacetimeDB module directory (defaults to "spacetimedb/") + modulePath.set(file("spacetimedb/")) +} +``` + +The plugin registers a `generateSpacetimeBindings` task that runs `spacetimedb-cli generate --lang kotlin` and wires the output into Kotlin compilation automatically. + +### Manual Setup + +Add the SDK dependency and generate bindings with the CLI: + +```kotlin +// build.gradle.kts +dependencies { + implementation("com.clockworklabs:spacetimedb-kotlin-sdk:0.1.0") +} +``` + +```bash +spacetimedb-cli generate \ + --lang kotlin \ + --out-dir src/main/kotlin/module_bindings/ \ + --module-path path/to/your/spacetimedb/module +``` + +## Quick Start + +```kotlin +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets +import module_bindings.* + +suspend fun main() { + val httpClient = HttpClient(OkHttp) { install(WebSockets) } + + val conn = DbConnection.Builder() + .withHttpClient(httpClient) + .withUri("ws://localhost:3000") + .withDatabaseName("my_module") + .withModuleBindings() + .onConnect { conn, identity, token -> + println("Connected as $identity") + + // Subscribe to tables + conn.subscriptionBuilder() + .addQuery { qb -> qb.person() } + .subscribe() + } + .onDisconnect { _, reason -> + println("Disconnected: $reason") + } + .build() + + // Register table callbacks + conn.db.person.onInsert { ctx, person -> + println("New person: ${person.name}") + } + + // Call reducers + conn.reducers.add("Alice") + + // Register reducer callbacks + conn.reducers.onAdd { ctx, name -> + println("add reducer called with: $name (status: ${ctx.status})") + } +} +``` + +## 
Generated Bindings + +Running codegen produces the following files: + +| File | Contents | +|------|----------| +| `Types.kt` | Data classes for all user-defined types | +| `*TableHandle.kt` | Table handle with callbacks, queries, and column metadata | +| `*Reducer.kt` | Reducer args data class and name constant | +| `RemoteTables.kt` | Aggregates all table accessors | +| `RemoteReducers.kt` | Reducer call stubs and per-reducer callbacks | +| `RemoteProcedures.kt` | Procedure call stubs | +| `Module.kt` | Module metadata, extension properties (`conn.db`, `conn.reducers`, `conn.procedures`), query builder | + +Extension properties are generated on both `DbConnection` and `EventContext`, so you can access `ctx.db.person` directly inside callbacks. + +## Connection Lifecycle + +### Builder Options + +```kotlin +DbConnection.Builder() + .withHttpClient(httpClient) // Ktor HttpClient with WebSockets (required) + .withUri("ws://localhost:3000") // WebSocket URI (required) + .withDatabaseName("my_module") // Module name or address (required) + .withModuleBindings() // Register generated module (required) + .withToken(savedToken) // Auth token for identity reuse + .withCompression(CompressionMode.GZIP) // Enable GZIP compression + .withLightMode(true) // Light mode (reduced server-side state) + .withCallbackDispatcher(Dispatchers.Main)// Dispatch callbacks on a specific dispatcher + .onConnect { conn, identity, token -> } // Fires once on successful connection + .onDisconnect { conn, reason -> } // Fires on disconnect + .onConnectError { conn, error -> } // Fires if connection fails + .build() // Returns connected DbConnection +``` + +### States + +A `DbConnection` transitions through these states: + +``` +DISCONNECTED → CONNECTING → CONNECTED → CLOSED +``` + +Once `CLOSED`, the connection cannot be reused. Create a new `DbConnection` to reconnect. + +### Reconnection + +The SDK does not reconnect automatically. 
Implement retry logic at the application level: + +```kotlin +suspend fun connectWithRetry(httpClient: HttpClient, maxAttempts: Int = 5): DbConnection { + repeat(maxAttempts) { attempt -> + try { + return DbConnection.Builder() + .withHttpClient(httpClient) + .withUri("ws://localhost:3000") + .withDatabaseName("my_module") + .withModuleBindings() + .build() + } catch (e: Exception) { + if (attempt == maxAttempts - 1) throw e + delay(1000L * (attempt + 1)) // linear backoff + } + } + error("unreachable") +} +``` + +## Subscriptions + +### SQL-string subscriptions + +```kotlin +// Subscribe to all rows +conn.subscribe("SELECT person.* FROM person") + +// Multiple queries +conn.subscribe( + "SELECT person.* FROM person", + "SELECT item.* FROM item", +) +``` + +### Type-safe query builder + +```kotlin +conn.subscriptionBuilder() + .addQuery { qb -> qb.person() } // all rows + .addQuery { qb -> qb.person().where { c -> c.name.eq("Alice") } } // filtered + .onApplied { ctx -> println("Subscription applied") } + .onError { ctx, err -> println("Subscription error: $err") } + .subscribe() +``` + +## Table Callbacks + +```kotlin +// Fires for each inserted row +conn.db.person.onInsert { ctx, person -> } + +// Fires for each deleted row (persistent tables only) +conn.db.person.onDelete { ctx, person -> } + +// Fires before delete (useful for cleanup/animation triggers) +conn.db.person.onBeforeDelete { ctx, person -> } +``` + +Remove callbacks by passing the same function reference to the corresponding `removeOn*` method. 
+ +## Reading Table Data + +```kotlin +// All cached rows +val people: List<Person> = conn.db.person.all() + +// Row count +val count: Int = conn.db.person.count() + +// Lazy iteration +conn.db.person.iter().forEach { person -> println(person.name) } +``` + +## One-Off Queries + +Execute a query outside of subscriptions: + +```kotlin +// Callback-based +conn.oneOffQuery("SELECT person.* FROM person") { result -> } + +// Suspend (with optional timeout) +val result = conn.oneOffQuery("SELECT person.* FROM person", timeout = 5.seconds) + +``` + +## Thread Safety + +The SDK is safe to use from any thread/coroutine: + +- **Client cache**: All row storage uses atomic references over persistent immutable collections (`kotlinx.collections.immutable`). No locks are needed — each reader gets a consistent snapshot via atomic reference reads. +- **Callback lists**: Stored as atomic `PersistentList` references. Adding/removing callbacks and iterating over them are lock-free operations. +- **Connection state**: Managed via atomic compare-and-swap, preventing double-connect or double-disconnect races. + +### Callback Dispatcher + +By default, callbacks execute on the WebSocket receive coroutine. To dispatch callbacks on a specific thread (e.g., the main/UI thread): + +```kotlin +DbConnection.Builder() + .withHttpClient(httpClient) + .withCallbackDispatcher(Dispatchers.Main) + // ... + .build() +``` + +This applies to all table, reducer, subscription, and connection callbacks. 
+ +## Dependencies + +| Library | Version | Purpose | +|---------|---------|---------| +| Ktor Client | 3.4.1 | WebSocket transport | +| kotlinx-coroutines | 1.10.2 | Async runtime | +| kotlinx-atomicfu | 0.31.0 | Lock-free atomics | +| kotlinx-collections-immutable | 0.4.0 | Persistent data structures | diff --git a/sdks/kotlin/build.gradle.kts b/sdks/kotlin/build.gradle.kts new file mode 100644 index 00000000000..d683254797c --- /dev/null +++ b/sdks/kotlin/build.gradle.kts @@ -0,0 +1,23 @@ +buildscript { + val SPACETIMEDB_CLI by extra("/home/fromml/Projects/SpacetimeDB/target/release/spacetimedb-cli") +} +plugins { + alias(libs.plugins.kotlinJvm) apply false + alias(libs.plugins.kotlinMultiplatform) apply false + alias(libs.plugins.androidKotlinMultiplatformLibrary) apply false +} + +subprojects { + afterEvaluate { + plugins.withId("org.jetbrains.kotlin.multiplatform") { + extensions.configure { + jvmToolchain(21) + } + } + plugins.withId("org.jetbrains.kotlin.jvm") { + extensions.configure { + jvmToolchain(21) + } + } + } +} diff --git a/sdks/kotlin/codegen-tests/.gitignore b/sdks/kotlin/codegen-tests/.gitignore new file mode 100644 index 00000000000..567609b1234 --- /dev/null +++ b/sdks/kotlin/codegen-tests/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/sdks/kotlin/codegen-tests/build.gradle.kts b/sdks/kotlin/codegen-tests/build.gradle.kts new file mode 100644 index 00000000000..49b6a2756ce --- /dev/null +++ b/sdks/kotlin/codegen-tests/build.gradle.kts @@ -0,0 +1,25 @@ +import java.util.Properties + +plugins { + alias(libs.plugins.kotlinJvm) + alias(libs.plugins.spacetimedb) +} + +spacetimedb { + modulePath.set(layout.projectDirectory.dir("spacetimedb")) + val localProps = rootProject.file("local.properties").let { f -> + if (f.exists()) Properties().also { it.load(f.inputStream()) } else null + } + (providers.environmentVariable("SPACETIMEDB_CLI").orNull + ?: localProps?.getProperty("spacetimedb.cli")) + ?.let { cli.set(file(it)) } +} + +dependencies { + 
implementation(project(":spacetimedb-sdk")) + testImplementation(libs.kotlin.test) +} + +tasks.test { + useJUnitPlatform() +} diff --git a/sdks/kotlin/codegen-tests/spacetimedb/Cargo.lock b/sdks/kotlin/codegen-tests/spacetimedb/Cargo.lock new file mode 100644 index 00000000000..2e5689fe7cc --- /dev/null +++ b/sdks/kotlin/codegen-tests/spacetimedb/Cargo.lock @@ -0,0 +1,966 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "approx" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0e60b75072ecd4168020818c0107f2857bb6c4e64252d8d3983f6263b40a5c3" +dependencies = [ + "num-traits", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "blake3" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + 
"cpufeatures", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + +[[package]] +name = "cc" +version = "1.2.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "num-traits", +] + +[[package]] +name = "codegen_test_kt" +version = "0.1.0" +dependencies = [ + "log", + "spacetimedb", +] + +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + +[[package]] +name = "convert_case" +version = "0.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "decorum" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "281759d3c8a14f5c3f0c49363be56810fcd7f910422f97f2db850c2920fde5cf" +dependencies = [ + "approx", + "num-traits", +] + +[[package]] +name = "derive_more" +version = "0.99.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "ethnum" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b" +dependencies = [ + "serde", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lean_string" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a262b6ae1dd9c2d3cf7977a816578b03bf8fb60b61545c395880f95eefc5b24" +dependencies = [ + "castaway", + "itoa", + "ryu", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + 
"autocfg", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + +[[package]] +name = "second-stack" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4904c83c6e51f1b9b08bfa5a86f35a51798e8307186e6f5513852210a219c0bb" + +[[package]] +name = 
"semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spacetimedb" +version = "2.1.0" +dependencies = [ + "anyhow", + "bytemuck", + "bytes", + "derive_more", + "getrandom 0.2.17", + "http", + "log", + 
"rand 0.8.5", + "scoped-tls", + "serde_json", + "spacetimedb-bindings-macro", + "spacetimedb-bindings-sys", + "spacetimedb-lib", + "spacetimedb-primitives", + "spacetimedb-query-builder", +] + +[[package]] +name = "spacetimedb-bindings-macro" +version = "2.1.0" +dependencies = [ + "heck 0.4.1", + "humantime", + "proc-macro2", + "quote", + "spacetimedb-primitives", + "syn", +] + +[[package]] +name = "spacetimedb-bindings-sys" +version = "2.1.0" +dependencies = [ + "spacetimedb-primitives", +] + +[[package]] +name = "spacetimedb-lib" +version = "2.1.0" +dependencies = [ + "anyhow", + "bitflags", + "blake3", + "chrono", + "derive_more", + "enum-as-inner", + "hex", + "itertools", + "log", + "spacetimedb-bindings-macro", + "spacetimedb-primitives", + "spacetimedb-sats", + "thiserror", +] + +[[package]] +name = "spacetimedb-primitives" +version = "2.1.0" +dependencies = [ + "bitflags", + "either", + "enum-as-inner", + "itertools", + "nohash-hasher", +] + +[[package]] +name = "spacetimedb-query-builder" +version = "2.1.0" +dependencies = [ + "spacetimedb-lib", +] + +[[package]] +name = "spacetimedb-sats" +version = "2.1.0" +dependencies = [ + "anyhow", + "arrayvec", + "bitflags", + "bytemuck", + "bytes", + "chrono", + "decorum", + "derive_more", + "enum-as-inner", + "ethnum", + "hex", + "itertools", + "lean_string", + "rand 0.9.2", + "second-stack", + "sha3", + "smallvec", + "spacetimedb-bindings-macro", + "spacetimedb-primitives", + "thiserror", + "uuid", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] 
+name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "uuid" +version = "1.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" +dependencies = [ + "getrandom 0.4.2", +] + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + 
"wit-bindgen", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck 0.5.0", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck 0.5.0", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + 
"proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "zerocopy" +version = "0.8.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/sdks/kotlin/codegen-tests/spacetimedb/Cargo.toml b/sdks/kotlin/codegen-tests/spacetimedb/Cargo.toml new file mode 100644 index 00000000000..00a670d9d7b --- /dev/null +++ b/sdks/kotlin/codegen-tests/spacetimedb/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "codegen_test_kt" +version = "0.1.0" +edition = "2021" + +[workspace] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +spacetimedb = { path = 
"/home/fromml/Projects/SpacetimeDB/crates/bindings" } +log.version = "0.4.17" diff --git a/sdks/kotlin/codegen-tests/spacetimedb/src/lib.rs b/sdks/kotlin/codegen-tests/spacetimedb/src/lib.rs new file mode 100644 index 00000000000..370303a2de9 --- /dev/null +++ b/sdks/kotlin/codegen-tests/spacetimedb/src/lib.rs @@ -0,0 +1,184 @@ +use spacetimedb::{ + ConnectionId, Identity, Query, ReducerContext, ScheduleAt, SpacetimeType, Table, Timestamp, + ViewContext, +}; +use spacetimedb::sats::{i256, u256}; + +// ───────────────────────────────────────────────────────────────────────────── +// PRODUCT TYPES +// ───────────────────────────────────────────────────────────────────────────── + +/// Empty product type — should generate `data object` in Kotlin. +#[derive(SpacetimeType)] +pub struct UnitStruct {} + +/// Product type with all primitive fields. +#[derive(SpacetimeType)] +pub struct AllPrimitives { + pub val_bool: bool, + pub val_i8: i8, + pub val_u8: u8, + pub val_i16: i16, + pub val_u16: u16, + pub val_i32: i32, + pub val_u32: u32, + pub val_i64: i64, + pub val_u64: u64, + pub val_i128: i128, + pub val_u128: u128, + pub val_i256: i256, + pub val_u256: u256, + pub val_f32: f32, + pub val_f64: f64, + pub val_string: String, + pub val_bytes: Vec, +} + +/// Product type with SDK-specific types. +#[derive(SpacetimeType)] +pub struct SdkTypes { + pub identity: Identity, + pub connection_id: ConnectionId, + pub timestamp: Timestamp, + pub schedule_at: ScheduleAt, +} + +/// Product type with optional and nested fields. 
+#[derive(SpacetimeType)] +pub struct NestedTypes { + pub optional_string: Option<String>, + pub optional_i32: Option<i32>, + pub list_of_strings: Vec<String>, + pub list_of_i32: Vec<i32>, + pub nested_struct: AllPrimitives, + pub optional_struct: Option<AllPrimitives>, +} + +// ───────────────────────────────────────────────────────────────────────────── +// SUM TYPES (ENUMS) +// ───────────────────────────────────────────────────────────────────────────── + +/// Plain enum — all unit variants, should generate `enum class` in Kotlin. +#[derive(SpacetimeType)] +pub enum SimpleEnum { + Alpha, + Beta, + Gamma, +} + +/// Mixed sum type — should generate `sealed interface` in Kotlin. +#[derive(SpacetimeType)] +pub enum MixedEnum { + UnitVariant, + StringVariant(String), + IntVariant(i32), + StructVariant(AllPrimitives), +} + +// ───────────────────────────────────────────────────────────────────────────── +// TABLES +// ───────────────────────────────────────────────────────────────────────────── + +/// Table referencing the empty product type. +#[spacetimedb::table(accessor = unit_test_row, public)] +pub struct UnitTestRow { + #[primary_key] + #[auto_inc] + id: u64, + value: UnitStruct, +} + +/// Table with all primitive types — verifies full type mapping. +#[spacetimedb::table(accessor = all_types_row, public)] +pub struct AllTypesRow { + #[primary_key] + #[auto_inc] + id: u64, + primitives: AllPrimitives, + sdk_types: SdkTypes, +} + +/// Table with optional/nested fields. +#[spacetimedb::table(accessor = nested_row, public)] +pub struct NestedRow { + #[primary_key] + #[auto_inc] + id: u64, + data: NestedTypes, + tag: SimpleEnum, + payload: Option<MixedEnum>, +} + +/// Table with indexes — verifies UniqueIndex and BTreeIndex codegen. 
+#[spacetimedb::table( + accessor = indexed_row, + public, + index(accessor = name_idx, btree(columns = [name])) +)] +pub struct IndexedRow { + #[primary_key] + #[auto_inc] + id: u64, + #[unique] + code: String, + name: String, +} + +/// Table without primary key — verifies content-key table cache. +#[spacetimedb::table(accessor = no_pk_row, public)] +pub struct NoPkRow { + label: String, + value: i32, +} + +// ───────────────────────────────────────────────────────────────────────────── +// VIEWS +// ───────────────────────────────────────────────────────────────────────────── + +/// Query-builder view over a PK table — should inherit primary key and generate +/// `RemotePersistentTableWithPrimaryKey` with `onUpdate` callbacks. +#[spacetimedb::view(accessor = all_indexed_rows, public)] +fn all_indexed_rows(ctx: &ViewContext) -> impl Query { + ctx.from.indexed_row() +} + +// ───────────────────────────────────────────────────────────────────────────── +// REDUCERS +// ───────────────────────────────────────────────────────────────────────────── + +#[spacetimedb::reducer(init)] +pub fn init(_ctx: &ReducerContext) {} + +/// No-arg reducer. +#[spacetimedb::reducer] +pub fn do_nothing(_ctx: &ReducerContext) {} + +/// Reducer with multiple typed args. +#[spacetimedb::reducer] +pub fn insert_all_types( + ctx: &ReducerContext, + primitives: AllPrimitives, + sdk_types: SdkTypes, +) { + ctx.db.all_types_row().insert(AllTypesRow { + id: 0, + primitives, + sdk_types, + }); +} + +/// Reducer with enum args. 
+#[spacetimedb::reducer] +pub fn insert_nested( + ctx: &ReducerContext, + data: NestedTypes, + tag: SimpleEnum, + payload: Option<MixedEnum>, +) { + ctx.db.nested_row().insert(NestedRow { + id: 0, + data, + tag, + payload, + }); +} diff --git a/sdks/kotlin/codegen-tests/src/test/kotlin/CodegenTest.kt b/sdks/kotlin/codegen-tests/src/test/kotlin/CodegenTest.kt new file mode 100644 index 00000000000..e4d4d41f61d --- /dev/null +++ b/sdks/kotlin/codegen-tests/src/test/kotlin/CodegenTest.kt @@ -0,0 +1,39 @@ +import module_bindings.UnitStruct +import module_bindings.UnitTestRow +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertSame + +class CodegenTest { + + @Test + fun `empty product type is data object`() { + // data object: equals by identity, singleton + assertSame(UnitStruct, UnitStruct) + assertEquals(UnitStruct.toString(), "UnitStruct") + } + + @Test + fun `empty product type round trips`() { + val writer = BsatnWriter() + UnitStruct.encode(writer) + val bytes = writer.toByteArray() + // Empty struct encodes to zero bytes + assertEquals(0, bytes.size) + + val decoded = UnitStruct.decode(BsatnReader(bytes)) + assertSame(UnitStruct, decoded) + } + + @Test + fun `table with empty product type round trips`() { + val row = UnitTestRow(id = 42u, value = UnitStruct) + val writer = BsatnWriter() + row.encode(writer) + val decoded = UnitTestRow.decode(BsatnReader(writer.toByteArray())) + assertEquals(row.id, decoded.id) + assertSame(UnitStruct, decoded.value) + } +} diff --git a/sdks/kotlin/gradle.properties b/sdks/kotlin/gradle.properties new file mode 100644 index 00000000000..17b3929474f --- /dev/null +++ b/sdks/kotlin/gradle.properties @@ -0,0 +1,13 @@ +#Kotlin +kotlin.code.style=official +kotlin.daemon.jvmargs=-Xmx3072M +kotlin.native.ignoreDisabledTargets=true + +#Gradle 
+org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 +org.gradle.configuration-cache=true +org.gradle.caching=true + +#Android +android.nonTransitiveRClass=true +android.useAndroidX=true diff --git a/sdks/kotlin/gradle/libs.versions.toml b/sdks/kotlin/gradle/libs.versions.toml new file mode 100644 index 00000000000..d69cde45fa0 --- /dev/null +++ b/sdks/kotlin/gradle/libs.versions.toml @@ -0,0 +1,29 @@ +[versions] +agp = "9.1.0" +android-compileSdk = "36" +android-minSdk = "26" +kotlin = "2.3.10" +kotlinx-coroutines = "1.10.2" +kotlinxAtomicfu = "0.31.0" +kotlinxCollectionsImmutable = "0.4.0" +ktor = "3.4.1" +brotli = "0.1.2" + +[libraries] +kotlinx-atomicfu = { module = "org.jetbrains.kotlinx:atomicfu", version.ref = "kotlinxAtomicfu" } +kotlinx-collections-immutable = { module = "org.jetbrains.kotlinx:kotlinx-collections-immutable", version.ref = "kotlinxCollectionsImmutable" } + +ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } +ktor-client-okhttp = { module = "io.ktor:ktor-client-okhttp", version.ref = "ktor" } +ktor-client-websockets = { module = "io.ktor:ktor-client-websockets", version.ref = "ktor" } + +brotli-dec = { module = "org.brotli:dec", version.ref = "brotli" } +kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } +kotlinx-coroutines-test = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-test", version.ref = "kotlinx-coroutines" } + +[plugins] +kotlinJvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } +kotlinMultiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" } +androidKotlinMultiplatformLibrary = { id = "com.android.kotlin.multiplatform.library", version.ref = "agp" } +spacetimedb = { id = "com.clockworklabs.spacetimedb" } diff --git a/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar 
b/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000000..2c3521197d7 Binary files /dev/null and b/sdks/kotlin/gradle/wrapper/gradle-wrapper.jar differ diff --git a/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties b/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000000..37f78a6af83 --- /dev/null +++ b/sdks/kotlin/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.3.1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/sdks/kotlin/gradlew b/sdks/kotlin/gradlew new file mode 100755 index 00000000000..f5feea6d6b1 --- /dev/null +++ b/sdks/kotlin/gradlew @@ -0,0 +1,252 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. 
If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. 
+while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s +' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. 
That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/sdks/kotlin/gradlew.bat b/sdks/kotlin/gradlew.bat new file mode 100644 index 00000000000..9b42019c791 --- /dev/null +++ b/sdks/kotlin/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. 
+@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 
1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/sdks/kotlin/integration-tests/.gitignore b/sdks/kotlin/integration-tests/.gitignore new file mode 100644 index 00000000000..567609b1234 --- /dev/null +++ b/sdks/kotlin/integration-tests/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/sdks/kotlin/integration-tests/build.gradle.kts b/sdks/kotlin/integration-tests/build.gradle.kts new file mode 100644 index 00000000000..ecb3432e084 --- /dev/null +++ b/sdks/kotlin/integration-tests/build.gradle.kts @@ -0,0 +1,41 @@ +import java.util.Properties + +plugins { + alias(libs.plugins.kotlinJvm) + alias(libs.plugins.spacetimedb) +} + +spacetimedb { + modulePath.set(layout.projectDirectory.dir("spacetimedb")) + val localProps = rootProject.file("local.properties").let { f -> + if (f.exists()) Properties().also { it.load(f.inputStream()) } else null + } + (providers.environmentVariable("SPACETIMEDB_CLI").orNull + ?: localProps?.getProperty("spacetimedb.cli")) + ?.let { cli.set(file(it)) } +} + +kotlin { + sourceSets.all { + languageSettings { + optIn("com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi") + } + } +} + +dependencies { + 
implementation(project(":spacetimedb-sdk")) + testImplementation(libs.kotlin.test) + testImplementation(libs.ktor.client.okhttp) + testImplementation(libs.ktor.client.websockets) + testImplementation(libs.kotlinx.coroutines.core) +} + +val integrationEnabled = providers.gradleProperty("integrationTests").isPresent + || providers.environmentVariable("SPACETIMEDB_HOST").isPresent + +tasks.test { + useJUnitPlatform() + testLogging.exceptionFormat = org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL + enabled = integrationEnabled +} diff --git a/sdks/kotlin/integration-tests/spacetimedb/Cargo.lock b/sdks/kotlin/integration-tests/spacetimedb/Cargo.lock new file mode 100644 index 00000000000..d9d51f7e5bd --- /dev/null +++ b/sdks/kotlin/integration-tests/spacetimedb/Cargo.lock @@ -0,0 +1,966 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "approx" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0e60b75072ecd4168020818c0107f2857bb6c4e64252d8d3983f6263b40a5c3" +dependencies = [ + "num-traits", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bitflags" +version = "2.11.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "blake3" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + "cpufeatures", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chat_kt" +version = "0.1.0" +dependencies = [ + "log", + "spacetimedb", +] + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "num-traits", +] + +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "decorum" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "281759d3c8a14f5c3f0c49363be56810fcd7f910422f97f2db850c2920fde5cf" +dependencies = [ + "approx", + "num-traits", +] + +[[package]] +name = "derive_more" +version = "0.99.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + 
+[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "ethnum" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b" +dependencies = [ + "serde", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "keccak" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "lean_string" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "962df00ba70ac8d5ca5c064e17e5c3d090c087fd8d21aa45096c716b169da514" +dependencies = [ + "castaway", + "itoa", + "ryu", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + +[[package]] +name = "second-stack" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4904c83c6e51f1b9b08bfa5a86f35a51798e8307186e6f5513852210a219c0bb" + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spacetimedb" +version = "2.1.0" +dependencies = [ + "anyhow", + "bytemuck", + "bytes", + "derive_more", + "getrandom 0.2.17", + "http", + "log", + "rand 0.8.5", + "scoped-tls", + "serde_json", + "spacetimedb-bindings-macro", + "spacetimedb-bindings-sys", + "spacetimedb-lib", + "spacetimedb-primitives", + "spacetimedb-query-builder", +] + +[[package]] +name = "spacetimedb-bindings-macro" +version = "2.1.0" +dependencies = [ + "heck 0.4.1", + "humantime", + "proc-macro2", + "quote", + "spacetimedb-primitives", + "syn", +] + +[[package]] +name = "spacetimedb-bindings-sys" +version = "2.1.0" +dependencies = [ + "spacetimedb-primitives", +] + +[[package]] +name = "spacetimedb-lib" +version = "2.1.0" +dependencies = [ + "anyhow", + "bitflags", + "blake3", + "chrono", + "derive_more", + "enum-as-inner", + "hex", + "itertools", + "log", + "spacetimedb-bindings-macro", + "spacetimedb-primitives", + "spacetimedb-sats", + "thiserror", +] + +[[package]] +name = "spacetimedb-primitives" +version = "2.1.0" +dependencies = [ + "bitflags", + "either", + "enum-as-inner", + "itertools", + "nohash-hasher", +] + +[[package]] +name = "spacetimedb-query-builder" +version = "2.1.0" +dependencies = [ + "spacetimedb-lib", +] + +[[package]] +name = "spacetimedb-sats" +version = "2.1.0" +dependencies = [ + "anyhow", + "arrayvec", + "bitflags", + "bytemuck", + "bytes", + "chrono", + "decorum", + "derive_more", + "enum-as-inner", + "ethnum", + "hex", + "itertools", + "lean_string", + "rand 0.9.2", + "second-stack", + "sha3", + "smallvec", + "spacetimedb-bindings-macro", + "spacetimedb-primitives", + "thiserror", + "uuid", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + 
"proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "uuid" +version = "1.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" +dependencies = [ + "getrandom 0.4.2", +] + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck 0.5.0", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck 0.5.0", + "indexmap", + "prettyplease", 
+ "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "zerocopy" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/sdks/kotlin/integration-tests/spacetimedb/Cargo.toml b/sdks/kotlin/integration-tests/spacetimedb/Cargo.toml new file 
mode 100644 index 00000000000..ba9e8126056 --- /dev/null +++ b/sdks/kotlin/integration-tests/spacetimedb/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "chat_kt" +version = "0.1.0" +edition = "2021" + +[workspace] + +[lib] +crate-type = ["cdylib"] + +[dependencies] +spacetimedb = { path = "../../../../crates/bindings", features = ["unstable"] } +log.version = "0.4.17" diff --git a/sdks/kotlin/integration-tests/spacetimedb/rust-toolchain.toml b/sdks/kotlin/integration-tests/spacetimedb/rust-toolchain.toml new file mode 100644 index 00000000000..28f0403f3d4 --- /dev/null +++ b/sdks/kotlin/integration-tests/spacetimedb/rust-toolchain.toml @@ -0,0 +1,8 @@ +[toolchain] +# change crates/{standalone,bench}/Dockerfile, .github/Dockerfile, and the docker image tag in +# .github/workflows/benchmarks.yml:jobs/callgrind_benchmark/container/image +# maybe also the rust-version in Cargo.toml +channel = "1.93.0" +profile = "default" +targets = ["wasm32-unknown-unknown"] +components = ["rust-src"] diff --git a/sdks/kotlin/integration-tests/spacetimedb/src/lib.rs b/sdks/kotlin/integration-tests/spacetimedb/src/lib.rs new file mode 100644 index 00000000000..0204d4f6423 --- /dev/null +++ b/sdks/kotlin/integration-tests/spacetimedb/src/lib.rs @@ -0,0 +1,257 @@ +use spacetimedb::{Identity, ProcedureContext, ReducerContext, ScheduleAt, Table, Timestamp}; +use spacetimedb::sats::{i256, u256}; + +#[spacetimedb::table(accessor = user, public)] +pub struct User { + #[primary_key] + identity: Identity, + name: Option<String>, + online: bool, +} + +#[spacetimedb::table(accessor = message, public)] +pub struct Message { + #[auto_inc] + #[primary_key] + id: u64, + sender: Identity, + sent: Timestamp, + text: String, +} + +/// A simple note table — used to test onDelete and filtered subscriptions. 
+#[spacetimedb::table(accessor = note, public)] +pub struct Note { + #[auto_inc] + #[primary_key] + id: u64, + owner: Identity, + content: String, + tag: String, +} + +/// Scheduled table — tests ScheduleAt and TimeDuration types. +/// When a row's scheduled_at time arrives, the server calls send_reminder. +#[spacetimedb::table(accessor = reminder, public, scheduled(send_reminder))] +pub struct Reminder { + #[primary_key] + #[auto_inc] + scheduled_id: u64, + scheduled_at: ScheduleAt, + text: String, + owner: Identity, +} + +/// Table with large integer fields — tests Int128/UInt128/Int256/UInt256 codegen. +#[spacetimedb::table(accessor = big_int_row, public)] +pub struct BigIntRow { + #[primary_key] + #[auto_inc] + id: u64, + val_i128: i128, + val_u128: u128, + val_i256: i256, + val_u256: u256, +} + +#[spacetimedb::reducer] +pub fn insert_big_ints( + ctx: &ReducerContext, + val_i128: i128, + val_u128: u128, + val_i256: i256, + val_u256: u256, +) -> Result<(), String> { + ctx.db.big_int_row().insert(BigIntRow { + id: 0, + val_i128, + val_u128, + val_i256, + val_u256, + }); + Ok(()) +} + +fn validate_name(name: String) -> Result<String, String> { + if name.is_empty() { + Err("Names must not be empty".to_string()) + } else { + Ok(name) + } +} + +#[spacetimedb::reducer] +pub fn set_name(ctx: &ReducerContext, name: String) -> Result<(), String> { + let name = validate_name(name)?; + if let Some(user) = ctx.db.user().identity().find(ctx.sender()) { + log::info!("User {} sets name to {name}", ctx.sender()); + ctx.db.user().identity().update(User { + name: Some(name), + ..user + }); + Ok(()) + } else { + Err("Cannot set name for unknown user".to_string()) + } +} + +fn validate_message(text: String) -> Result<String, String> { + if text.is_empty() { + Err("Messages must not be empty".to_string()) + } else { + Ok(text) + } +} + +#[spacetimedb::reducer] +pub fn send_message(ctx: &ReducerContext, text: String) -> Result<(), String> { + let text = validate_message(text)?; + log::info!("User {}: {text}", 
ctx.sender()); + ctx.db.message().insert(Message { + id: 0, + sender: ctx.sender(), + text, + sent: ctx.timestamp, + }); + Ok(()) +} + +#[spacetimedb::reducer] +pub fn delete_message(ctx: &ReducerContext, message_id: u64) -> Result<(), String> { + if let Some(msg) = ctx.db.message().id().find(message_id) { + if msg.sender != ctx.sender() { + return Err("Cannot delete another user's message".to_string()); + } + ctx.db.message().id().delete(message_id); + log::info!("User {} deleted message {message_id}", ctx.sender()); + Ok(()) + } else { + Err("Message not found".to_string()) + } +} + +#[spacetimedb::reducer] +pub fn add_note(ctx: &ReducerContext, content: String, tag: String) -> Result<(), String> { + if content.is_empty() { + return Err("Note content must not be empty".to_string()); + } + ctx.db.note().insert(Note { + id: 0, + owner: ctx.sender(), + content, + tag, + }); + Ok(()) +} + +#[spacetimedb::reducer] +pub fn delete_note(ctx: &ReducerContext, note_id: u64) -> Result<(), String> { + if let Some(note) = ctx.db.note().id().find(note_id) { + if note.owner != ctx.sender() { + return Err("Cannot delete another user's note".to_string()); + } + ctx.db.note().id().delete(note_id); + Ok(()) + } else { + Err("Note not found".to_string()) + } +} + +/// Schedule a one-shot reminder that fires after delay_ms milliseconds. +#[spacetimedb::reducer] +pub fn schedule_reminder(ctx: &ReducerContext, text: String, delay_ms: u64) -> Result<(), String> { + if text.is_empty() { + return Err("Reminder text must not be empty".to_string()); + } + let at = ctx.timestamp + std::time::Duration::from_millis(delay_ms); + ctx.db.reminder().insert(Reminder { + scheduled_id: 0, + scheduled_at: ScheduleAt::Time(at), + text: text.clone(), + owner: ctx.sender(), + }); + log::info!("User {} scheduled reminder in {delay_ms}ms: {text}", ctx.sender()); + Ok(()) +} + +/// Schedule a repeating reminder that fires every interval_ms milliseconds. 
+#[spacetimedb::reducer] +pub fn schedule_reminder_repeat(ctx: &ReducerContext, text: String, interval_ms: u64) -> Result<(), String> { + if text.is_empty() { + return Err("Reminder text must not be empty".to_string()); + } + let interval = std::time::Duration::from_millis(interval_ms); + ctx.db.reminder().insert(Reminder { + scheduled_id: 0, + scheduled_at: interval.into(), + text: text.clone(), + owner: ctx.sender(), + }); + log::info!("User {} scheduled repeating reminder every {interval_ms}ms: {text}", ctx.sender()); + Ok(()) +} + +/// Cancel a scheduled reminder by id. +#[spacetimedb::reducer] +pub fn cancel_reminder(ctx: &ReducerContext, reminder_id: u64) -> Result<(), String> { + if let Some(reminder) = ctx.db.reminder().scheduled_id().find(reminder_id) { + if reminder.owner != ctx.sender() { + return Err("Cannot cancel another user's reminder".to_string()); + } + ctx.db.reminder().scheduled_id().delete(reminder_id); + log::info!("User {} cancelled reminder {reminder_id}", ctx.sender()); + Ok(()) + } else { + Err("Reminder not found".to_string()) + } +} + +/// Called by the scheduler when a reminder fires. +#[spacetimedb::reducer] +pub fn send_reminder(ctx: &ReducerContext, reminder: Reminder) { + log::info!("Reminder fired for {}: {}", reminder.owner, reminder.text); + // Insert a system message so the client sees it + ctx.db.message().insert(Message { + id: 0, + sender: reminder.owner, + text: format!("[REMINDER] {}", reminder.text), + sent: ctx.timestamp, + }); +} + +/// Simple procedure that echoes a greeting. +#[spacetimedb::procedure] +pub fn greet(_ctx: &mut ProcedureContext, name: String) -> String { + format!("Hello, {name}!") +} + +/// No-arg procedure that returns a constant. 
+#[spacetimedb::procedure] +pub fn server_ping(_ctx: &mut ProcedureContext) -> String { + "pong".to_string() +} + +#[spacetimedb::reducer(init)] +pub fn init(_ctx: &ReducerContext) {} + +#[spacetimedb::reducer(client_connected)] +pub fn identity_connected(ctx: &ReducerContext) { + if let Some(user) = ctx.db.user().identity().find(ctx.sender()) { + ctx.db.user().identity().update(User { online: true, ..user }); + } else { + ctx.db.user().insert(User { + name: None, + identity: ctx.sender(), + online: true, + }); + } +} + +#[spacetimedb::reducer(client_disconnected)] +pub fn identity_disconnected(ctx: &ReducerContext) { + if let Some(user) = ctx.db.user().identity().find(ctx.sender()) { + ctx.db.user().identity().update(User { online: false, ..user }); + } else { + log::warn!("Disconnect event for unknown user with identity {:?}", ctx.sender()); + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/BigIntTypeTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/BigIntTypeTest.kt new file mode 100644 index 00000000000..a0fef0692c9 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/BigIntTypeTest.kt @@ -0,0 +1,176 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Int128 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Int256 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UInt128 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UInt256 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger +import module_bindings.BigIntRow +import module_bindings.InsertBigIntsArgs +import 
kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue + +/** + * Integration tests for generated BigIntRow and InsertBigIntsArgs types + * that use Int128, UInt128, Int256, UInt256 value classes. + */ +class BigIntTypeTest { + + private val ONE = BigInteger.ONE + + // --- BigIntRow encode/decode round-trip --- + + @Test + fun `BigIntRow encode decode round-trip with zero values`() { + val row = BigIntRow( + id = 1UL, + valI128 = Int128.ZERO, + valU128 = UInt128.ZERO, + valI256 = Int256.ZERO, + valU256 = UInt256.ZERO, + ) + val decoded = encodeDecode(row) + assertEquals(row, decoded) + } + + @Test + fun `BigIntRow encode decode round-trip with max values`() { + val row = BigIntRow( + id = 42UL, + valI128 = Int128(ONE.shl(127) - ONE), // I128 max + valU128 = UInt128(ONE.shl(128) - ONE), // U128 max + valI256 = Int256(ONE.shl(255) - ONE), // I256 max + valU256 = UInt256(ONE.shl(256) - ONE), // U256 max + ) + val decoded = encodeDecode(row) + assertEquals(row, decoded) + } + + @Test + fun `BigIntRow encode decode round-trip with min signed values`() { + val row = BigIntRow( + id = 7UL, + valI128 = Int128(-ONE.shl(127)), // I128 min + valU128 = UInt128.ZERO, + valI256 = Int256(-ONE.shl(255)), // I256 min + valU256 = UInt256.ZERO, + ) + val decoded = encodeDecode(row) + assertEquals(row, decoded) + } + + @Test + fun `BigIntRow encode decode round-trip with small values`() { + val row = BigIntRow( + id = 3UL, + valI128 = Int128(BigInteger(-999)), + valU128 = UInt128(BigInteger(12345)), + valI256 = Int256(BigInteger(-67890)), + valU256 = UInt256(BigInteger(11111)), + ) + val decoded = encodeDecode(row) + assertEquals(row, decoded) + } + + // --- InsertBigIntsArgs encode/decode round-trip --- + + @Test + fun `InsertBigIntsArgs encode decode round-trip`() { + val args = InsertBigIntsArgs( + valI128 = Int128(BigInteger(42)), + valU128 = UInt128(BigInteger(100)), + valI256 = Int256(BigInteger(-200)), + valU256 = 
UInt256(BigInteger(300)), + ) + val bytes = args.encode() + val reader = BsatnReader(bytes) + val decoded = InsertBigIntsArgs.decode(reader) + assertEquals(0, reader.remaining, "All bytes should be consumed") + assertEquals(args, decoded) + } + + // --- BigIntRow data class equality --- + + @Test + fun `BigIntRow equals same values`() { + val a = makeBigIntRow(1UL, 42) + val b = makeBigIntRow(1UL, 42) + assertEquals(a, b) + } + + @Test + fun `BigIntRow not equals different i128`() { + val a = makeBigIntRow(1UL, 42) + val b = makeBigIntRow(1UL, 99) + assertNotEquals(a, b) + } + + @Test + fun `BigIntRow hashCode consistent with equals`() { + val a = makeBigIntRow(1UL, 42) + val b = makeBigIntRow(1UL, 42) + assertEquals(a.hashCode(), b.hashCode()) + } + + @Test + fun `BigIntRow toString contains field values`() { + val row = makeBigIntRow(5UL, 123) + val str = row.toString() + assertTrue(str.contains("BigIntRow"), "toString should contain class name: $str") + assertTrue(str.contains("5"), "toString should contain id: $str") + } + + @Test + fun `BigIntRow copy preserves unchanged fields`() { + val original = makeBigIntRow(1UL, 42) + val copy = original.copy(id = 99UL) + assertEquals(99UL, copy.id) + assertEquals(original.valI128, copy.valI128) + assertEquals(original.valU128, copy.valU128) + assertEquals(original.valI256, copy.valI256) + assertEquals(original.valU256, copy.valU256) + } + + @Test + fun `BigIntRow destructuring`() { + val row = makeBigIntRow(10UL, 77) + val (id, valI128, valU128, valI256, valU256) = row + assertEquals(10UL, id) + assertEquals(Int128(BigInteger(77)), valI128) + assertEquals(UInt128(BigInteger(77)), valU128) + assertEquals(Int256(BigInteger(77)), valI256) + assertEquals(UInt256(BigInteger(77)), valU256) + } + + // --- Value class type safety --- + + @Test + fun `value classes are distinct types`() { + val i128 = Int128(BigInteger(42)) + val u128 = UInt128(BigInteger(42)) + assertNotEquals(i128, u128) + } + + // --- Helpers --- + + private 
fun encodeDecode(row: BigIntRow): BigIntRow { + val writer = BsatnWriter() + row.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = BigIntRow.decode(reader) + assertEquals(0, reader.remaining, "All bytes should be consumed") + return decoded + } + + private fun makeBigIntRow(id: ULong, v: Int): BigIntRow = BigIntRow( + id = id, + valI128 = Int128(BigInteger(v)), + valU128 = UInt128(BigInteger(v)), + valI256 = Int256(BigInteger(v)), + valU256 = UInt256(BigInteger(v)), + ) +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/BsatnRoundtripTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/BsatnRoundtripTest.kt new file mode 100644 index 00000000000..18383f3bb3b --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/BsatnRoundtripTest.kt @@ -0,0 +1,461 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ScheduleAt +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import module_bindings.Message +import module_bindings.Note +import module_bindings.Reminder +import module_bindings.User +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +/** + * BSATN binary serialization roundtrip tests. 
+ */ +class BsatnRoundtripTest { + + // --- Primitive type roundtrips --- + + @Test + fun `bool roundtrip`() { + for (value in listOf(true, false)) { + val writer = BsatnWriter() + writer.writeBool(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readBool()) + } + } + + @Test + fun `byte and ubyte roundtrip`() { + val writer = BsatnWriter() + writer.writeByte(0x7F) + writer.writeU8(0xFFu.toUByte()) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0x7F.toByte(), reader.readByte()) + assertEquals(0xFFu.toUByte(), reader.readU8()) + } + + @Test + fun `i8 roundtrip`() { + for (value in listOf(Byte.MIN_VALUE, 0.toByte(), Byte.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeI8(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readI8()) + } + } + + @Test + fun `u8 roundtrip`() { + for (value in listOf(UByte.MIN_VALUE, UByte.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeU8(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readU8()) + } + } + + @Test + fun `i16 roundtrip`() { + for (value in listOf(Short.MIN_VALUE, 0.toShort(), Short.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeI16(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readI16()) + } + } + + @Test + fun `u16 roundtrip`() { + for (value in listOf(UShort.MIN_VALUE, UShort.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeU16(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readU16()) + } + } + + @Test + fun `i32 roundtrip`() { + for (value in listOf(Int.MIN_VALUE, 0, Int.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeI32(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readI32()) + } + } + + @Test + fun `u32 roundtrip`() { + for (value in listOf(UInt.MIN_VALUE, UInt.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeU32(value) + val 
reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readU32()) + } + } + + @Test + fun `i64 roundtrip`() { + for (value in listOf(Long.MIN_VALUE, 0L, Long.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeI64(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readI64()) + } + } + + @Test + fun `u64 roundtrip`() { + for (value in listOf(ULong.MIN_VALUE, ULong.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeU64(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readU64()) + } + } + + @Test + fun `f32 roundtrip`() { + for (value in listOf(0.0f, 1.5f, -3.14f, Float.MIN_VALUE, Float.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeF32(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readF32()) + } + } + + @Test + fun `f64 roundtrip`() { + for (value in listOf(0.0, 2.718281828, -1.0e100, Double.MIN_VALUE, Double.MAX_VALUE)) { + val writer = BsatnWriter() + writer.writeF64(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readF64()) + } + } + + @Test + fun `string roundtrip`() { + for (value in listOf("", "hello", "O'Reilly", "emoji: \uD83D\uDE00", "line\nnewline")) { + val writer = BsatnWriter() + writer.writeString(value) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(value, reader.readString()) + } + } + + @Test + fun `bytearray roundtrip`() { + val value = byteArrayOf(0, 1, 127, -128, -1) + val writer = BsatnWriter() + writer.writeByteArray(value) + val reader = BsatnReader(writer.toByteArray()) + assertTrue(value.contentEquals(reader.readByteArray())) + } + + // --- Multiple values in sequence --- + + @Test + fun `multiple primitives in sequence`() { + val writer = BsatnWriter() + writer.writeBool(true) + writer.writeI32(42) + writer.writeU64(999UL) + writer.writeString("test") + writer.writeF64(3.14) + + val reader = BsatnReader(writer.toByteArray()) + 
assertEquals(true, reader.readBool()) + assertEquals(42, reader.readI32()) + assertEquals(999UL, reader.readU64()) + assertEquals("test", reader.readString()) + assertEquals(3.14, reader.readF64()) + } + + // --- SDK type roundtrips --- + + @Test + fun `Identity encode-decode roundtrip`() { + val original = Identity.fromHexString("ab".repeat(32)) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Identity.decode(reader) + assertEquals(original, decoded) + } + + @Test + fun `Identity zero encode-decode roundtrip`() { + val original = Identity.zero() + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Identity.decode(reader) + assertEquals(original, decoded) + assertEquals("00".repeat(32), decoded.toHexString()) + } + + @Test + fun `ConnectionId encode-decode roundtrip`() { + val original = ConnectionId.random() + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = ConnectionId.decode(reader) + assertEquals(original, decoded) + } + + @Test + fun `ConnectionId zero encode-decode roundtrip`() { + val original = ConnectionId.zero() + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = ConnectionId.decode(reader) + assertEquals(original, decoded) + } + + @Test + fun `Timestamp encode-decode roundtrip`() { + val original = Timestamp.now() + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Timestamp.decode(reader) + assertEquals(original.microsSinceUnixEpoch, decoded.microsSinceUnixEpoch) + } + + @Test + fun `Timestamp UNIX_EPOCH encode-decode roundtrip`() { + val original = Timestamp.UNIX_EPOCH + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Timestamp.decode(reader) + 
assertEquals(original, decoded) + assertEquals(0L, decoded.microsSinceUnixEpoch) + } + + @Test + fun `ScheduleAt Time encode-decode roundtrip`() { + val original = ScheduleAt.Time(Timestamp.fromMillis(1700000000000L)) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = ScheduleAt.decode(reader) + assertIs<ScheduleAt.Time>(decoded, "Should decode as Time") + assertEquals( + original.timestamp.microsSinceUnixEpoch, + decoded.timestamp.microsSinceUnixEpoch + ) + } + + @Test + fun `ScheduleAt Interval encode-decode roundtrip`() { + val original = ScheduleAt.interval(5.seconds) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = ScheduleAt.decode(reader) + assertEquals(original, decoded) + } + + // --- Generated type roundtrips --- + + @Test + fun `User encode-decode roundtrip with name`() { + val original = User( + identity = Identity.fromHexString("ab".repeat(32)), + name = "Alice", + online = true + ) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = User.decode(reader) + assertEquals(original, decoded) + } + + @Test + fun `User encode-decode roundtrip with null name`() { + val original = User( + identity = Identity.fromHexString("cd".repeat(32)), + name = null, + online = false + ) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = User.decode(reader) + assertEquals(original, decoded) + } + + @Test + fun `Message encode-decode roundtrip`() { + val original = Message( + id = 42UL, + sender = Identity.fromHexString("ab".repeat(32)), + sent = Timestamp.fromMillis(1700000000000L), + text = "Hello, world!" 
+ ) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Message.decode(reader) + assertEquals(original.id, decoded.id) + assertEquals(original.sender, decoded.sender) + assertEquals(original.sent.microsSinceUnixEpoch, decoded.sent.microsSinceUnixEpoch) + assertEquals(original.text, decoded.text) + } + + @Test + fun `Note encode-decode roundtrip`() { + val original = Note( + id = 7UL, + owner = Identity.fromHexString("ef".repeat(32)), + content = "Test note with special chars: O'Reilly & \"quotes\"", + tag = "test-tag" + ) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Note.decode(reader) + assertEquals(original, decoded) + } + + @Test + fun `Reminder encode-decode roundtrip`() { + val original = Reminder( + scheduledId = 100UL, + scheduledAt = ScheduleAt.interval(10.seconds), + text = "Don't forget!", + owner = Identity.fromHexString("11".repeat(32)) + ) + val writer = BsatnWriter() + original.encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val decoded = Reminder.decode(reader) + assertEquals(original, decoded) + } + + // --- Writer utilities --- + + @Test + fun `writer toByteArray returns correct length`() { + val writer = BsatnWriter() + writer.writeI32(42) + assertEquals(4, writer.toByteArray().size, "i32 should be 4 bytes") + } + + @Test + fun `writer toBase64 produces non-empty string`() { + val writer = BsatnWriter() + writer.writeString("hello") + val base64 = writer.toBase64() + assertTrue(base64.isNotEmpty(), "Base64 should not be empty") + } + + @Test + fun `writer reset clears data`() { + val writer = BsatnWriter() + writer.writeI32(42) + assertTrue(writer.toByteArray().isNotEmpty()) + writer.reset() + assertEquals(0, writer.toByteArray().size, "After reset, writer should be empty") + } + + @Test + fun `reader remaining tracks bytes left`() { + val writer = BsatnWriter() + writer.writeI32(10) + 
writer.writeI32(20) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(8, reader.remaining) + reader.readI32() + assertEquals(4, reader.remaining) + reader.readI32() + assertEquals(0, reader.remaining) + } + + @Test + fun `reader offset tracks position`() { + val writer = BsatnWriter() + writer.writeI32(10) + writer.writeI64(20L) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(0, reader.offset) + reader.readI32() + assertEquals(4, reader.offset) + reader.readI64() + assertEquals(12, reader.offset) + } + + // --- SumTag and ArrayLen --- + + @Test + fun `sumTag roundtrip`() { + for (tag in listOf(0u.toUByte(), 1u.toUByte(), 255u.toUByte())) { + val writer = BsatnWriter() + writer.writeSumTag(tag) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(tag, reader.readSumTag()) + } + } + + @Test + fun `arrayLen roundtrip`() { + for (len in listOf(0, 1, 100, 65535)) { + val writer = BsatnWriter() + writer.writeArrayLen(len) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(len, reader.readArrayLen()) + } + } + + // --- Little-endian byte order verification --- + + @Test + fun `i32 is little-endian`() { + val writer = BsatnWriter() + writer.writeI32(1) + val bytes = writer.toByteArray() + assertEquals(4, bytes.size) + // 1 in little-endian i32 = [0x01, 0x00, 0x00, 0x00] + assertEquals(0x01.toByte(), bytes[0]) + assertEquals(0x00.toByte(), bytes[1]) + assertEquals(0x00.toByte(), bytes[2]) + assertEquals(0x00.toByte(), bytes[3]) + } + + @Test + fun `u16 is little-endian`() { + val writer = BsatnWriter() + writer.writeU16(0x0102u.toUShort()) + val bytes = writer.toByteArray() + assertEquals(2, bytes.size) + // 0x0102 in little-endian = [0x02, 0x01] + assertEquals(0x02.toByte(), bytes[0]) + assertEquals(0x01.toByte(), bytes[1]) + } + + @Test + fun `f64 is little-endian IEEE 754`() { + val writer = BsatnWriter() + writer.writeF64(1.0) + val bytes = writer.toByteArray() + assertEquals(8, bytes.size) + // 1.0 as IEEE 754 
double LE = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F] + assertEquals(0x00.toByte(), bytes[0]) + assertEquals(0x3F.toByte(), bytes[7]) + assertEquals(0xF0.toByte(), bytes[6]) + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ColComparisonTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ColComparisonTest.kt new file mode 100644 index 00000000000..76f47729086 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ColComparisonTest.kt @@ -0,0 +1,175 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SqlLit +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.QueryBuilder +import module_bindings.addQuery +import module_bindings.db +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertTrue + +class ColComparisonTest { + + // --- SQL generation tests for lt/lte/gt/gte --- + + @Test + fun `Col lt generates correct SQL`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> c.id.lt(SqlLit.ulong(100uL)) }.toSql() + assertTrue(sql.contains("< 100"), "Should contain '< 100': $sql") + } + + @Test + fun `Col lte generates correct SQL`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> c.id.lte(SqlLit.ulong(100uL)) }.toSql() + assertTrue(sql.contains("<= 100"), "Should contain '<= 100': $sql") + } + + @Test + fun `Col gt generates correct SQL`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> c.id.gt(SqlLit.ulong(0uL)) }.toSql() + assertTrue(sql.contains("> 0"), "Should contain '> 0': $sql") + } + + @Test + fun `Col gte generates correct SQL`() { + val qb = 
QueryBuilder() + val sql = qb.note().where { c -> c.id.gte(SqlLit.ulong(1uL)) }.toSql() + assertTrue(sql.contains(">= 1"), "Should contain '>= 1': $sql") + } + + // --- Live subscribe tests --- + + @Test + fun `gt with live subscribe returns matching rows`() = runBlocking { + val client = connectToDb() + + // First subscribe to all notes to see what's there + client.subscribeAll() + + // Insert a note so we have at least one + val insertDone = CompletableDeferred() + client.conn.db.note.onInsert { ctx, note -> + if (ctx !is EventContext.SubscribeApplied + && note.owner == client.identity && note.tag == "gt-test") { + insertDone.complete(note.id) + } + } + client.conn.reducers.addNote("gt-content", "gt-test") + val noteId = withTimeout(DEFAULT_TIMEOUT_MS) { insertDone.await() } + + // Now create a second connection with a gt filter + val client2 = connectToDb() + val applied = CompletableDeferred() + + client2.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .addQuery { qb -> qb.note().where { c -> c.id.gte(SqlLit.ulong(noteId)) } } + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + val notes = client2.conn.db.note.all() + assertTrue(notes.all { it.id >= noteId }, "All notes should have id >= $noteId") + + client2.conn.disconnect() + client.cleanup() + } + + // --- Chained where().where() tests --- + + @Test + fun `chained where produces AND clause`() { + val qb = QueryBuilder() + val sql = qb.note() + .where { c -> c.tag.eq(SqlLit.string("test")) } + .where { c -> c.content.eq(SqlLit.string("hello")) } + .toSql() + assertTrue(sql.contains("AND"), "Chained where should produce AND: $sql") + assertTrue(sql.contains("tag"), "Should contain first where column: $sql") + assertTrue(sql.contains("content"), "Should contain second where column: $sql") + } + + @Test + fun `triple chained where produces two ANDs`() { + val qb = QueryBuilder() + 
val sql = qb.note() + .where { c -> c.tag.eq(SqlLit.string("a")) } + .where { c -> c.content.eq(SqlLit.string("b")) } + .where { c -> c.id.gt(SqlLit.ulong(0uL)) } + .toSql() + // Count AND occurrences + val andCount = Regex("AND").findAll(sql).count() + assertTrue(andCount >= 2, "Triple chain should have >= 2 ANDs, got $andCount: $sql") + } + + @Test + fun `chained where with live subscribe works`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + // Insert a note with known tag+content + val insertDone = CompletableDeferred() + client.conn.db.note.onInsert { ctx, note -> + if (ctx !is EventContext.SubscribeApplied + && note.owner == client.identity && note.tag == "chain-test") { + insertDone.complete(Unit) + } + } + client.conn.reducers.addNote("chain-content", "chain-test") + withTimeout(DEFAULT_TIMEOUT_MS) { insertDone.await() } + + // Second client subscribes with chained where + val client2 = connectToDb() + val applied = CompletableDeferred() + + client2.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .addQuery { qb -> + qb.note() + .where { c -> c.tag.eq(SqlLit.string("chain-test")) } + .where { c -> c.content.eq(SqlLit.string("chain-content")) } + } + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + val notes = client2.conn.db.note.all() + assertTrue(notes.isNotEmpty(), "Should have at least one note matching both where clauses") + assertTrue(notes.all { it.tag == "chain-test" && it.content == "chain-content" }, + "All notes should match both conditions") + + client2.conn.disconnect() + client.cleanup() + } + + // --- Col.eq with another Col (self-join condition) --- + + @Test + fun `Col eq with another Col generates column comparison SQL`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> c.tag.eq(c.content) }.toSql() + assertTrue(sql.contains("\"tag\"") && sql.contains("\"content\""), 
"Should reference both columns: $sql") + assertTrue(sql.contains("="), "Should have = operator: $sql") + } + + // --- filter alias on FromWhere --- + + @Test + fun `filter on FromWhere chains like where`() { + val qb = QueryBuilder() + val sql = qb.note() + .where { c -> c.tag.eq(SqlLit.string("a")) } + .filter { c -> c.content.eq(SqlLit.string("b")) } + .toSql() + assertTrue(sql.contains("AND"), "filter after where should also AND: $sql") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ColExtensionsTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ColExtensionsTest.kt new file mode 100644 index 00000000000..cea507b2c88 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ColExtensionsTest.kt @@ -0,0 +1,147 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import module_bindings.QueryBuilder +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ColExtensionsTest { + + // Test that convenience extensions produce the same SQL as explicit SqlLit calls + + // --- String extensions --- + + @Test + fun `String eq extension matches SqlLit eq`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.tag.eq("hello") }.toSql() + val withLit = qb.note().where { c -> c.tag.eq(SqlLit.string("hello")) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `String neq extension matches SqlLit neq`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.tag.neq("hello") }.toSql() + val withLit = qb.note().where { c -> c.tag.neq(SqlLit.string("hello")) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `String lt extension matches SqlLit 
lt`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.tag.lt("z") }.toSql() + val withLit = qb.note().where { c -> c.tag.lt(SqlLit.string("z")) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `String lte extension`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.tag.lte("z") }.toSql() + val withLit = qb.note().where { c -> c.tag.lte(SqlLit.string("z")) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `String gt extension matches SqlLit gt`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.tag.gt("a") }.toSql() + val withLit = qb.note().where { c -> c.tag.gt(SqlLit.string("a")) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `String gte extension`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.tag.gte("a") }.toSql() + val withLit = qb.note().where { c -> c.tag.gte(SqlLit.string("a")) }.toSql() + assertEquals(withLit, withExt) + } + + // --- Boolean extensions --- + + @Test + fun `Boolean eq extension matches SqlLit eq`() { + val qb = QueryBuilder() + val withExt = qb.user().where { c -> c.online.eq(true) }.toSql() + val withLit = qb.user().where { c -> c.online.eq(SqlLit.bool(true)) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `Boolean neq extension matches SqlLit neq`() { + val qb = QueryBuilder() + val withExt = qb.user().where { c -> c.online.neq(false) }.toSql() + val withLit = qb.user().where { c -> c.online.neq(SqlLit.bool(false)) }.toSql() + assertEquals(withLit, withExt) + } + + // --- NullableCol String extensions --- + + @Test + fun `NullableCol String eq extension`() { + val qb = QueryBuilder() + val withExt = qb.user().where { c -> c.name.eq("alice") }.toSql() + val withLit = qb.user().where { c -> c.name.eq(SqlLit.string("alice")) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `NullableCol String gte extension`() { + val qb = QueryBuilder() + val withExt = qb.user().where { c -> 
c.name.gte("a") }.toSql() + val withLit = qb.user().where { c -> c.name.gte(SqlLit.string("a")) }.toSql() + assertEquals(withLit, withExt) + } + + // --- ULong extensions (note.id is Col) --- + + @Test + fun `ULong eq extension`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.id.eq(42uL) }.toSql() + val withLit = qb.note().where { c -> c.id.eq(SqlLit.ulong(42uL)) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `ULong lt extension`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.id.lt(100uL) }.toSql() + val withLit = qb.note().where { c -> c.id.lt(SqlLit.ulong(100uL)) }.toSql() + assertEquals(withLit, withExt) + } + + @Test + fun `ULong gte extension`() { + val qb = QueryBuilder() + val withExt = qb.note().where { c -> c.id.gte(1uL) }.toSql() + val withLit = qb.note().where { c -> c.id.gte(SqlLit.ulong(1uL)) }.toSql() + assertEquals(withLit, withExt) + } + + // --- IxCol Identity extension (user identity is IxCol) --- + + @Test + fun `IxCol Identity eq extension`() { + val qb = QueryBuilder() + val id = Identity.zero() + val withExt = qb.user().where { _, ix -> ix.identity.eq(id) }.toSql() + val withLit = qb.user().where { _, ix -> ix.identity.eq(SqlLit.identity(id)) }.toSql() + assertEquals(withLit, withExt) + } + + // --- Verify convenience extensions produce valid SQL --- + + @Test + fun `convenience extensions produce valid SQL structure`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> c.tag.eq("test") }.toSql() + assertTrue(sql.contains("SELECT"), "Should be a SELECT: $sql") + assertTrue(sql.contains("WHERE"), "Should have WHERE: $sql") + assertTrue(sql.contains("'test'"), "Should contain quoted value: $sql") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/CompressionTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/CompressionTest.kt new file mode 100644 
index 00000000000..c74b448faa6 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/CompressionTest.kt @@ -0,0 +1,167 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Int128 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Int256 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UInt128 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UInt256 +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.onFailure +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.onSuccess +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.procedures +import module_bindings.reducers +import module_bindings.withModuleBindings +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Shared compression test logic. Each subclass sets the [mode] and + * all tests run end-to-end over that compression mode. 
+ */ +abstract class CompressionTestBase(private val mode: CompressionMode) { + + private suspend fun connect(): ConnectedClient { + val identityDeferred = CompletableDeferred>() + + val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withCompression(mode) + .withModuleBindings() + .onConnect { _, identity, tok -> + identityDeferred.complete(identity to tok) + } + .onConnectError { _, e -> + identityDeferred.completeExceptionally(e) + } + .build() + + val (identity, tok) = withTimeout(DEFAULT_TIMEOUT_MS) { identityDeferred.await() } + return ConnectedClient(conn = conn, identity = identity, token = tok) + } + + @Test + fun `send message`() = runBlocking { + val client = connect() + client.subscribeAll() + + val text = "$mode-msg-${System.nanoTime()}" + val received = CompletableDeferred() + client.conn.db.message.onInsert { _, row -> + if (row.text == text) received.complete(row.text) + } + client.conn.reducers.sendMessage(text) + + assertEquals(text, withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.cleanup() + } + + @Test + fun `set name`() = runBlocking { + val client = connect() + client.subscribeAll() + + val name = "$mode-user-${System.nanoTime()}" + val received = CompletableDeferred() + client.conn.db.user.onUpdate { _, _, newRow -> + if (newRow.identity == client.identity && newRow.name == name) { + received.complete(newRow.name) + } + } + client.conn.reducers.setName(name) + + assertEquals(name, withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.cleanup() + } + + @Test + fun `insert big ints`() = runBlocking { + val client = connect() + client.subscribeAll() + + val one = BigInteger.ONE + val i128 = Int128(one.shl(100)) + val u128 = UInt128(one.shl(120)) + val i256 = Int256(one.shl(200)) + val u256 = UInt256(one.shl(250)) + + val received = CompletableDeferred() + client.conn.db.bigIntRow.onInsert { _, row -> + if (row.valI128 == i128 && row.valU128 == u128 
&& + row.valI256 == i256 && row.valU256 == u256 + ) { + received.complete(true) + } + } + client.conn.reducers.insertBigInts(i128, u128, i256, u256) + + assertTrue(withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.cleanup() + } + + @Test + fun `add note`() = runBlocking { + val client = connect() + client.subscribeAll() + + val content = "$mode-note-${System.nanoTime()}" + val tag = "test-tag" + val received = CompletableDeferred() + client.conn.db.note.onInsert { _, row -> + if (row.content == content && row.tag == tag) { + received.complete(row.content) + } + } + client.conn.reducers.addNote(content, tag) + + assertEquals(content, withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.cleanup() + } + + @Test + fun `call greet procedure`() = runBlocking { + val client = connect() + + val received = CompletableDeferred() + client.conn.procedures.greet("World") { _, result -> + result + .onSuccess { received.complete(it) } + .onFailure { received.completeExceptionally(Exception("$it")) } + } + + assertEquals("Hello, World!", withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.conn.disconnect() + } + + @Test + fun `call server ping procedure`() = runBlocking { + val client = connect() + + val received = CompletableDeferred() + client.conn.procedures.serverPing { _, result -> + result + .onSuccess { received.complete(it) } + .onFailure { received.completeExceptionally(Exception("$it")) } + } + + assertEquals("pong", withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.conn.disconnect() + } +} + +/** Tests with no compression. */ +class NoneCompressionTest : CompressionTestBase(CompressionMode.NONE) + +/** Tests with GZIP compression. */ +class GzipCompressionTest : CompressionTestBase(CompressionMode.GZIP) + +/** Tests with Brotli compression. 
*/ +class BrotliCompressionTest : CompressionTestBase(CompressionMode.BROTLI) diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ConnectionIdTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ConnectionIdTest.kt new file mode 100644 index 00000000000..ef50f1dc002 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ConnectionIdTest.kt @@ -0,0 +1,156 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class ConnectionIdTest { + + // --- Factories --- + + @Test + fun `zero creates zero connectionId`() { + val id = ConnectionId.zero() + assertTrue(id.isZero(), "zero() should be zero") + assertEquals("0".repeat(32), id.toHexString(), "Zero connId should be 32 zeros") + } + + @Test + fun `random creates non-zero connectionId`() { + val id = ConnectionId.random() + assertTrue(!id.isZero(), "random() should not be zero") + } + + @Test + fun `random creates unique values`() { + val a = ConnectionId.random() + val b = ConnectionId.random() + assertNotEquals(a, b, "Two random connectionIds should differ") + } + + @Test + fun `fromHexString parses valid hex`() { + val hex = "ab".repeat(16) // 32 hex chars = 16 bytes = U128 + val id = ConnectionId.fromHexString(hex) + assertTrue(id.toHexString().contains("ab"), "Should contain ab") + } + + @Test + fun `fromHexString roundtrips`() { + val hex = "0123456789abcdef".repeat(2) // 32 hex chars + val id = ConnectionId.fromHexString(hex) + assertEquals(hex, id.toHexString()) + } + + @Test + fun `fromHexString rejects invalid hex`() { 
+ assertFailsWith { + ConnectionId.fromHexString("not-hex!") + } + } + + @Test + fun `fromHexStringOrNull returns null for invalid hex`() { + val result = ConnectionId.fromHexStringOrNull("not-valid") + assertNull(result, "Invalid hex should return null") + } + + @Test + fun `fromHexStringOrNull returns null for zero hex`() { + val result = ConnectionId.fromHexStringOrNull("0".repeat(32)) + assertNull(result, "Zero hex should return null (nullIfZero)") + } + + @Test + fun `fromHexStringOrNull returns non-null for valid nonzero hex`() { + val result = ConnectionId.fromHexStringOrNull("ab".repeat(16)) + assertNotNull(result, "Valid nonzero hex should return non-null") + } + + // --- nullIfZero --- + + @Test + fun `nullIfZero returns null for zero`() { + assertNull(ConnectionId.nullIfZero(ConnectionId.zero())) + } + + @Test + fun `nullIfZero returns identity for nonzero`() { + val id = ConnectionId.random() + assertEquals(id, ConnectionId.nullIfZero(id)) + } + + // --- Conversions --- + + @Test + fun `toHexString returns 32 lowercase hex chars`() { + val id = ConnectionId.random() + val hex = id.toHexString() + assertEquals(32, hex.length, "Hex should be 32 chars: $hex") + assertTrue(hex.all { it in '0'..'9' || it in 'a'..'f' }, "Should be lowercase hex: $hex") + } + + @Test + fun `toByteArray returns 16 bytes`() { + val id = ConnectionId.random() + assertEquals(16, id.toByteArray().size) + } + + @Test + fun `zero toByteArray is all zeros`() { + val bytes = ConnectionId.zero().toByteArray() + assertTrue(bytes.all { it == 0.toByte() }, "Zero bytes should all be 0") + } + + @Test + fun `toString equals toHexString`() { + val id = ConnectionId.random() + assertEquals(id.toHexString(), id.toString()) + } + + // --- isZero --- + + @Test + fun `isZero true for zero`() { + assertTrue(ConnectionId.zero().isZero()) + } + + @Test + fun `isZero false for random`() { + assertTrue(!ConnectionId.random().isZero()) + } + + // --- equals / hashCode --- + + @Test + fun `equal 
connectionIds have same hashCode`() { + val hex = "ab".repeat(16) + val a = ConnectionId.fromHexString(hex) + val b = ConnectionId.fromHexString(hex) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } + + // --- Live connectionId from connection --- + + @Test + fun `connectionId from connection is non-null`() = kotlinx.coroutines.runBlocking { + val client = connectToDb() + assertNotNull(client.conn.connectionId, "connectionId should be non-null after connect") + client.conn.disconnect() + } + + @Test + fun `connectionId from connection has valid hex`() = kotlinx.coroutines.runBlocking { + val client = connectToDb() + val hex = client.conn.connectionId!!.toHexString() + assertEquals(32, hex.length) + assertTrue(hex.all { it in '0'..'9' || it in 'a'..'f' }) + client.conn.disconnect() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionBuilderErrorTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionBuilderErrorTest.kt new file mode 100644 index 00000000000..57971d6e6b6 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionBuilderErrorTest.kt @@ -0,0 +1,103 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.withModuleBindings +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class DbConnectionBuilderErrorTest { + + @Test + fun `build with invalid URI fires onConnectError`() = runBlocking { + val error = CompletableDeferred() + + DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + 
.withUri("ws://localhost:99999") + .withDatabaseName(DB_NAME) + .withModuleBindings() + .onConnect { _, _, _ -> error.completeExceptionally(AssertionError("Should not connect")) } + .onConnectError { _, e -> error.complete(e) } + .build() + + val ex = withTimeout(DEFAULT_TIMEOUT_MS) { error.await() } + assertNotNull(ex, "Should receive an error on invalid URI") + Unit + } + + @Test + fun `build with unreachable host fires onConnectError`() = runBlocking { + val error = CompletableDeferred() + + DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri("ws://192.0.2.1:3000") + .withDatabaseName(DB_NAME) + .withModuleBindings() + .onConnect { _, _, _ -> error.completeExceptionally(AssertionError("Should not connect")) } + .onConnectError { _, e -> error.complete(e) } + .build() + + val ex = withTimeout(DEFAULT_TIMEOUT_MS) { error.await() } + assertNotNull(ex, "Should receive an error on unreachable host") + Unit + } + + @Test + fun `build with invalid database name fires onConnectError`() = runBlocking { + val error = CompletableDeferred() + + DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName("nonexistent-db-${System.nanoTime()}") + .withModuleBindings() + .onConnect { _, _, _ -> error.completeExceptionally(AssertionError("Should not connect")) } + .onConnectError { _, e -> error.complete(e) } + .build() + + val ex = withTimeout(DEFAULT_TIMEOUT_MS) { error.await() } + assertNotNull(ex, "Should receive an error on invalid database name") + Unit + } + + @Test + fun `isActive is false after connect error`() = runBlocking { + val error = CompletableDeferred() + + val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri("ws://localhost:99999") + .withDatabaseName(DB_NAME) + .withModuleBindings() + .onConnectError { _, e -> error.complete(e) } + .build() + + withTimeout(DEFAULT_TIMEOUT_MS) { error.await() } + assertTrue(!conn.isActive, "isActive should be false after connect 
error") + } + + @Test + fun `build with garbage token fires onConnectError`() = runBlocking { + val error = CompletableDeferred() + + DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withToken("not-a-valid-token") + .withModuleBindings() + .onConnect { _, _, _ -> error.completeExceptionally(AssertionError("Should not connect with invalid token")) } + .onConnectError { _, e -> error.complete(e) } + .build() + + val ex = withTimeout(DEFAULT_TIMEOUT_MS) { error.await() } + assertNotNull(ex, "Should receive an error on invalid token") + assertEquals(ex.message?.contains("401"), true, "Error should mention 401: ${ex.message}") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionDisconnectTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionDisconnectTest.kt new file mode 100644 index 00000000000..07205918913 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionDisconnectTest.kt @@ -0,0 +1,92 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class DbConnectionDisconnectTest { + + @Test + fun `double disconnect does not throw`() = runBlocking { + val client = connectToDb() + assertTrue(client.conn.isActive) + + client.conn.disconnect() + assertFalse(client.conn.isActive) + + // Second disconnect should be a no-op + client.conn.disconnect() + assertFalse(client.conn.isActive) + } + + @Test + fun `disconnect fires onDisconnect callback`() = 
runBlocking { + val client = connectToDb() + val disconnected = CompletableDeferred() + + client.conn.onDisconnect { _, error -> + disconnected.complete(error) + } + + client.conn.disconnect() + + val error = withTimeout(DEFAULT_TIMEOUT_MS) { disconnected.await() } + assertEquals(error, null, "Clean disconnect should have null error, got: $error") + } + + @Test + fun `reducer call after disconnect does not crash`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + client.conn.disconnect() + assertFalse(client.conn.isActive) + + // Calling a reducer on a disconnected connection should not crash + try { + client.conn.reducers.sendMessage("should-not-arrive") + } catch (_: Exception) { + // Expected — some SDKs throw, some silently fail + } + } + + @Test + fun `suspend oneOffQuery after disconnect throws immediately`() = runBlocking { + // After disconnect the send channel is closed, so oneOffQuery throws + // IllegalStateException immediately rather than hanging. + val client = connectToDb() + client.conn.disconnect() + + var threw = false + try { + withTimeout(2000) { + client.conn.oneOffQuery("SELECT * FROM user") + } + } catch (_: TimeoutCancellationException) { + threw = true + } catch (_: Exception) { + threw = true + } + assertTrue(threw, "suspend oneOffQuery on disconnected conn should fail") + } + + @Test + fun `callback oneOffQuery after disconnect does not crash`() = runBlocking { + val client = connectToDb() + client.conn.disconnect() + + // Callback variant — just fires and forgets, callback never invoked + try { + client.conn.oneOffQuery("SELECT * FROM user") { _ -> } + } catch (_: Exception) { + // Expected + } + Unit + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionIsActiveTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionIsActiveTest.kt new file mode 100644 index 
00000000000..2264908d6ad --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionIsActiveTest.kt @@ -0,0 +1,20 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import kotlinx.coroutines.runBlocking +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class DbConnectionIsActiveTest { + + @Test + fun `isActive reflects connection lifecycle`() = runBlocking { + val client = connectToDb() + + assertTrue(client.conn.isActive, "Should be active after connect") + + client.conn.disconnect() + + assertFalse(client.conn.isActive, "Should be inactive after disconnect") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionUseTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionUseTest.kt new file mode 100644 index 00000000000..13358acdd53 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/DbConnectionUseTest.kt @@ -0,0 +1,95 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.use +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.cancel +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.withModuleBindings +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class DbConnectionUseTest { + + private suspend fun buildConnectedDb(): DbConnection { + val connected = CompletableDeferred() + 
val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withModuleBindings() + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { _, e -> connected.completeExceptionally(e) } + .build() + withTimeout(DEFAULT_TIMEOUT_MS) { connected.await() } + return conn + } + + @Test + fun `use block auto-disconnects after block completes`() = runBlocking { + val conn = buildConnectedDb() + assertTrue(conn.isActive, "Connection should be active before use{}") + + conn.use { + assertTrue(it.isActive, "Connection should be active inside use{}") + } + + assertFalse(conn.isActive, "Connection should be inactive after use{}") + } + + @Test + fun `use block disconnects even when exception is thrown`() = runBlocking { + val conn = buildConnectedDb() + assertTrue(conn.isActive) + + assertFailsWith { + conn.use { + throw IllegalStateException("test error inside use{}") + } + } + + assertFalse(conn.isActive, "Connection should be inactive after exception in use{}") + } + + @Test + fun `use block propagates return value`() = runBlocking { + val conn = buildConnectedDb() + + val result = conn.use { 42 } + + assertEquals(42, result, "use{} should propagate the return value") + assertFalse(conn.isActive) + } + + @Test + fun `use block disconnects on coroutine cancellation`() = runBlocking { + val conn = buildConnectedDb() + assertTrue(conn.isActive) + + try { + coroutineScope { + launch { + conn.use { + // Cancel the outer scope while inside use{} + this@coroutineScope.cancel("test cancellation") + // Suspend to let cancellation propagate + kotlinx.coroutines.delay(Long.MAX_VALUE) + } + } + } + } catch (_: CancellationException) { + // expected + } + + // Give NonCancellable disconnect a moment to complete + kotlinx.coroutines.delay(500) + assertFalse(conn.isActive, "Connection should be inactive after cancellation") + } +} diff --git 
a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/EventContextTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/EventContextTest.kt new file mode 100644 index 00000000000..4fa08be9565 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/EventContextTest.kt @@ -0,0 +1,182 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Status +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertIs +import kotlin.test.assertTrue + +class EventContextTest { + + @Test + fun `reducer context has callerIdentity matching our identity`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val callerIdentityDeferred = CompletableDeferred() + client.conn.reducers.onSetName { c, _ -> + if (c.callerIdentity == client.identity) callerIdentityDeferred.complete(c.callerIdentity) + } + client.conn.reducers.setName("ctx-test-${System.nanoTime()}") + + val callerIdentity = withTimeout(DEFAULT_TIMEOUT_MS) { callerIdentityDeferred.await() } + assertEquals(client.identity, callerIdentity, "callerIdentity should match our identity") + + client.conn.disconnect() + } + + @Test + fun `reducer context has non-null callerConnectionId`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val connIdDeferred = CompletableDeferred() + 
client.conn.reducers.onSetName { c, _ -> + if (c.callerIdentity == client.identity) connIdDeferred.complete(c.callerConnectionId) + } + client.conn.reducers.setName("ctx-connid-${System.nanoTime()}") + + val connId = withTimeout(DEFAULT_TIMEOUT_MS) { connIdDeferred.await() } + assertNotNull(connId, "callerConnectionId should not be null for our own reducer call") + + client.conn.disconnect() + } + + @Test + fun `successful reducer has Status Committed`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val statusDeferred = CompletableDeferred() + client.conn.reducers.onSetName { c, _ -> + if (c.callerIdentity == client.identity) statusDeferred.complete(c.status) + } + client.conn.reducers.setName("status-ok-${System.nanoTime()}") + + val s = withTimeout(DEFAULT_TIMEOUT_MS) { statusDeferred.await() } + assertTrue(s is Status.Committed, "Successful reducer should have Status.Committed, got: $s") + + client.conn.disconnect() + } + + @Test + fun `failed reducer has Status Failed`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val statusDeferred = CompletableDeferred() + client.conn.reducers.onSetName { c, _ -> + if (c.callerIdentity == client.identity) statusDeferred.complete(c.status) + } + // Setting empty name should fail (server validates non-empty) + client.conn.reducers.setName("") + + val s = withTimeout(DEFAULT_TIMEOUT_MS) { statusDeferred.await() } + assertIs(s, "Empty name reducer should have Status.Failed, got: $s") + val failedMsg = s.message + assertTrue(failedMsg.isNotEmpty(), "Failed status should have a message: $failedMsg") + + client.conn.disconnect() + } + + @Test + fun `reducer context has reducerName`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val nameDeferred = CompletableDeferred() + client.conn.reducers.onSetName { c, _ -> + if (c.callerIdentity == client.identity) nameDeferred.complete(c.reducerName) + } + 
client.conn.reducers.setName("reducer-name-test-${System.nanoTime()}") + + val reducerName = withTimeout(DEFAULT_TIMEOUT_MS) { nameDeferred.await() } + assertEquals("set_name", reducerName, "reducerName should be 'set_name'") + + client.conn.disconnect() + } + + @Test + fun `reducer context has timestamp`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val tsDeferred = CompletableDeferred() + client.conn.reducers.onSetName { c, _ -> + if (c.callerIdentity == client.identity) tsDeferred.complete(c.timestamp) + } + client.conn.reducers.setName("ts-test-${System.nanoTime()}") + + val ts = withTimeout(DEFAULT_TIMEOUT_MS) { tsDeferred.await() } + assertNotNull(ts, "timestamp should not be null") + + client.conn.disconnect() + } + + @Test + fun `reducer context args contain the argument passed`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val uniqueName = "args-test-${System.nanoTime()}" + val argsDeferred = CompletableDeferred() + client.conn.reducers.onSetName { c, name -> + if (c.callerIdentity == client.identity && name == uniqueName) { + argsDeferred.complete(name) + } + } + client.conn.reducers.setName(uniqueName) + + val receivedName = withTimeout(DEFAULT_TIMEOUT_MS) { argsDeferred.await() } + assertEquals(uniqueName, receivedName, "Callback should receive the name argument") + + client.conn.disconnect() + } + + @Test + fun `onInsert receives SubscribeApplied context during initial subscription`() = runBlocking { + val client = connectToDb() + + val gotSubscribeApplied = CompletableDeferred() + client.conn.db.user.onInsert { ctx, _ -> + if (ctx is EventContext.SubscribeApplied) { + gotSubscribeApplied.complete(true) + } + } + + client.subscribeAll() + + val result = withTimeout(DEFAULT_TIMEOUT_MS) { gotSubscribeApplied.await() } + assertTrue(result, "onInsert during subscribe should receive SubscribeApplied context") + + client.conn.disconnect() + } + + @Test + fun `onInsert receives non-SubscribeApplied 
context for live inserts`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val ctxClass = CompletableDeferred() + client.conn.db.note.onInsert { c, note -> + if (c !is EventContext.SubscribeApplied && note.owner == client.identity && note.tag == "live-ctx") { + ctxClass.complete(c::class.simpleName ?: "unknown") + } + } + client.conn.reducers.addNote("live-context-test", "live-ctx") + + val className = withTimeout(DEFAULT_TIMEOUT_MS) { ctxClass.await() } + assertTrue(className != "SubscribeApplied", "Live insert should NOT be SubscribeApplied, got: $className") + + client.cleanup() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/GeneratedTypeTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/GeneratedTypeTest.kt new file mode 100644 index 00000000000..d108da1e0ad --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/GeneratedTypeTest.kt @@ -0,0 +1,246 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ScheduleAt +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import module_bindings.Message +import module_bindings.Note +import module_bindings.Reminder +import module_bindings.User +import module_bindings.db +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.minutes + +/** + * Generated data class equality, hashCode, toString, and copy tests. 
+ */ +class GeneratedTypeTest { + + private val identity1 = Identity.fromHexString("aa".repeat(32)) + private val identity2 = Identity.fromHexString("bb".repeat(32)) + private val ts = Timestamp.fromMillis(1700000000000L) + + // --- User equals/hashCode --- + + @Test + fun `User equals same values`() { + val a = User(identity1, "Alice", true) + val b = User(identity1, "Alice", true) + assertEquals(a, b) + } + + @Test + fun `User not equals different identity`() { + val a = User(identity1, "Alice", true) + val b = User(identity2, "Alice", true) + assertNotEquals(a, b) + } + + @Test + fun `User not equals different name`() { + val a = User(identity1, "Alice", true) + val b = User(identity1, "Bob", true) + assertNotEquals(a, b) + } + + @Test + fun `User not equals different online`() { + val a = User(identity1, "Alice", true) + val b = User(identity1, "Alice", false) + assertNotEquals(a, b) + } + + @Test + fun `User equals with null name`() { + val a = User(identity1, null, false) + val b = User(identity1, null, false) + assertEquals(a, b) + } + + @Test + fun `User not equals null vs non-null name`() { + val a = User(identity1, null, true) + val b = User(identity1, "Alice", true) + assertNotEquals(a, b) + } + + @Test + fun `User hashCode consistent with equals`() { + val a = User(identity1, "Alice", true) + val b = User(identity1, "Alice", true) + assertEquals(a.hashCode(), b.hashCode()) + } + + @Test + fun `User hashCode differs for different values`() { + val a = User(identity1, "Alice", true) + val b = User(identity2, "Bob", false) + assertNotEquals(a.hashCode(), b.hashCode()) + } + + // --- User toString --- + + @Test + fun `User toString contains field values`() { + val user = User(identity1, "Alice", true) + val str = user.toString() + assertTrue(str.contains("Alice"), "toString should contain name: $str") + assertTrue(str.contains("true"), "toString should contain online: $str") + assertTrue(str.contains("User"), "toString should contain class name: $str") + } 
+ + @Test + fun `User toString with null name`() { + val user = User(identity1, null, false) + val str = user.toString() + assertTrue(str.contains("null"), "toString should show null for name: $str") + } + + // --- Message equals/hashCode --- + + @Test + fun `Message equals same values`() { + val a = Message(1UL, identity1, ts, "hello") + val b = Message(1UL, identity1, ts, "hello") + assertEquals(a, b) + } + + @Test + fun `Message not equals different id`() { + val a = Message(1UL, identity1, ts, "hello") + val b = Message(2UL, identity1, ts, "hello") + assertNotEquals(a, b) + } + + @Test + fun `Message not equals different text`() { + val a = Message(1UL, identity1, ts, "hello") + val b = Message(1UL, identity1, ts, "world") + assertNotEquals(a, b) + } + + @Test + fun `Message toString contains field values`() { + val msg = Message(42UL, identity1, ts, "test message") + val str = msg.toString() + assertTrue(str.contains("42"), "toString should contain id: $str") + assertTrue(str.contains("test message"), "toString should contain text: $str") + assertTrue(str.contains("Message"), "toString should contain class name: $str") + } + + // --- Note equals/hashCode --- + + @Test + fun `Note equals same values`() { + val a = Note(1UL, identity1, "content", "tag") + val b = Note(1UL, identity1, "content", "tag") + assertEquals(a, b) + } + + @Test + fun `Note not equals different tag`() { + val a = Note(1UL, identity1, "content", "tag1") + val b = Note(1UL, identity1, "content", "tag2") + assertNotEquals(a, b) + } + + @Test + fun `Note hashCode consistent with equals`() { + val a = Note(5UL, identity1, "x", "y") + val b = Note(5UL, identity1, "x", "y") + assertEquals(a.hashCode(), b.hashCode()) + } + + // --- Reminder equals/hashCode --- + + @Test + fun `Reminder equals same values`() { + val sa = ScheduleAt.interval(5.minutes) + val a = Reminder(1UL, sa, "remind me", identity1) + val b = Reminder(1UL, sa, "remind me", identity1) + assertEquals(a, b) + } + + @Test + fun 
`Reminder not equals different text`() { + val sa = ScheduleAt.interval(5.minutes) + val a = Reminder(1UL, sa, "first", identity1) + val b = Reminder(1UL, sa, "second", identity1) + assertNotEquals(a, b) + } + + @Test + fun `Reminder toString contains field values`() { + val sa = ScheduleAt.interval(5.minutes) + val r = Reminder(99UL, sa, "reminder text", identity1) + val str = r.toString() + assertTrue(str.contains("99"), "toString should contain scheduledId: $str") + assertTrue(str.contains("reminder text"), "toString should contain text: $str") + assertTrue(str.contains("Reminder"), "toString should contain class name: $str") + } + + // --- Copy (Kotlin data class feature) --- + + @Test + fun `User copy preserves unchanged fields`() { + val original = User(identity1, "Alice", true) + val copy = original.copy(name = "Bob") + assertEquals(identity1, copy.identity) + assertEquals("Bob", copy.name) + assertEquals(true, copy.online) + } + + @Test + fun `Message copy with different id`() { + val original = Message(1UL, identity1, ts, "hello") + val copy = original.copy(id = 99UL) + assertEquals(99UL, copy.id) + assertEquals(identity1, copy.sender) + assertEquals("hello", copy.text) + } + + // --- Destructuring (Kotlin data class feature) --- + + @Test + fun `User destructuring`() { + val user = User(identity1, "Alice", true) + val (identity, name, online) = user + assertEquals(identity1, identity) + assertEquals("Alice", name) + assertEquals(true, online) + } + + @Test + fun `Note destructuring`() { + val note = Note(7UL, identity1, "content", "tag") + val (id, owner, content, tag) = note + assertEquals(7UL, id) + assertEquals(identity1, owner) + assertEquals("content", content) + assertEquals("tag", tag) + } + + // --- Live roundtrip through server --- + + @Test + fun `User from server has correct data class behavior`() = kotlinx.coroutines.runBlocking { + val client = connectToDb() + client.subscribeAll() + + val user = 
client.conn.db.user.identity.find(client.identity)!! + + // data class equals works with server-returned instances + val userCopy = user.copy() + assertEquals(user, userCopy) + assertEquals(user.hashCode(), userCopy.hashCode()) + + // toString is meaningful + val str = user.toString() + assertTrue(str.contains("User"), "Server user toString: $str") + + client.conn.disconnect() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/IdentityTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/IdentityTest.kt new file mode 100644 index 00000000000..0aa87819107 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/IdentityTest.kt @@ -0,0 +1,136 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue + +class IdentityTest { + + // --- Factories --- + + @Test + fun `zero creates zero identity`() { + val id = Identity.zero() + assertEquals("0".repeat(64), id.toHexString(), "Zero identity should be 64 zeros") + } + + @Test + fun `fromHexString parses valid hex`() { + val hex = "ab".repeat(32) // 64 hex chars = 32 bytes = U256 + val id = Identity.fromHexString(hex) + assertTrue(id.toHexString().contains("ab"), "Should contain ab: ${id.toHexString()}") + } + + @Test + fun `fromHexString roundtrips`() { + val hex = "0123456789abcdef".repeat(4) // 64 hex chars + val id = Identity.fromHexString(hex) + assertEquals(hex, id.toHexString()) + } + + @Test + fun `fromHexString rejects invalid hex`() { + assertFailsWith { + Identity.fromHexString("not-valid-hex") + } + } + + // --- Conversions --- + + @Test + fun `toHexString returns 64 lowercase hex chars`() 
{ + val hex = "ab".repeat(32) + val id = Identity.fromHexString(hex) + val result = id.toHexString() + assertEquals(64, result.length, "Hex should be 64 chars: $result") + assertTrue(result.all { it in '0'..'9' || it in 'a'..'f' }, "Should be lowercase hex: $result") + } + + @Test + fun `toByteArray returns 32 bytes`() { + val id = Identity.zero() + val bytes = id.toByteArray() + assertEquals(32, bytes.size, "Identity should be 32 bytes") + } + + @Test + fun `zero toByteArray is all zeros`() { + val bytes = Identity.zero().toByteArray() + assertTrue(bytes.all { it == 0.toByte() }, "Zero identity bytes should all be 0") + } + + @Test + fun `toString returns hex string`() { + val id = Identity.zero() + assertEquals(id.toHexString(), id.toString()) + } + + // --- Comparison --- + + @Test + fun `compareTo zero vs nonzero`() { + val zero = Identity.zero() + val nonzero = Identity.fromHexString("00".repeat(31) + "01") + assertTrue(zero < nonzero, "Zero should be less than nonzero") + } + + @Test + fun `compareTo equal identities`() { + val a = Identity.fromHexString("ab".repeat(32)) + val b = Identity.fromHexString("ab".repeat(32)) + assertEquals(0, a.compareTo(b)) + } + + @Test + fun `compareTo is reflexive`() { + val id = Identity.fromHexString("cd".repeat(32)) + assertEquals(0, id.compareTo(id)) + } + + // --- equals / hashCode --- + + @Test + fun `equal identities have same hashCode`() { + val a = Identity.fromHexString("ab".repeat(32)) + val b = Identity.fromHexString("ab".repeat(32)) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } + + @Test + fun `different identities are not equal`() { + val a = Identity.fromHexString("ab".repeat(32)) + val b = Identity.fromHexString("cd".repeat(32)) + assertNotEquals(a, b) + } + + // --- Live identity from connection --- + + @Test + fun `identity from connection has valid hex string`() = kotlinx.coroutines.runBlocking { + val client = connectToDb() + val hex = client.identity.toHexString() + assertEquals(64, 
hex.length, "Live identity hex should be 64 chars") + assertTrue(hex.all { it in '0'..'9' || it in 'a'..'f' }, "Should be valid hex: $hex") + client.conn.disconnect() + } + + @Test + fun `identity from connection has 32-byte array`() = kotlinx.coroutines.runBlocking { + val client = connectToDb() + assertEquals(32, client.identity.toByteArray().size) + client.conn.disconnect() + } + + @Test + fun `identity fromHexString roundtrips with live identity`() = kotlinx.coroutines.runBlocking { + val client = connectToDb() + val hex = client.identity.toHexString() + val parsed = Identity.fromHexString(hex) + assertEquals(client.identity, parsed) + client.conn.disconnect() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/JoinTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/JoinTest.kt new file mode 100644 index 00000000000..864353f63a7 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/JoinTest.kt @@ -0,0 +1,80 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SqlLit +import kotlinx.coroutines.runBlocking +import module_bindings.QueryBuilder +import kotlin.test.Test +import kotlin.test.assertTrue + +class JoinTest { + + @Test + fun `leftSemijoin generates valid SQL`() = runBlocking { + val qb = QueryBuilder() + // note.id JOIN message.id (both IxCol<*, ULong>) — synthetic but tests the API + val query = qb.note().leftSemijoin(qb.message()) { left, right -> + left.id.eq(right.id) + } + val sql = query.toSql() + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + assertTrue(sql.contains("\"note\".*"), "Should select note.*: $sql") + assertTrue(sql.contains("\"note\".\"id\" = \"message\".\"id\""), "Should have ON clause: $sql") + } + + @Test + fun `rightSemijoin generates valid SQL`() = 
runBlocking { + val qb = QueryBuilder() + val query = qb.note().rightSemijoin(qb.message()) { left, right -> + left.id.eq(right.id) + } + val sql = query.toSql() + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + assertTrue(sql.contains("\"message\".*"), "Should select message.*: $sql") + } + + @Test + fun `leftSemijoin with where clause`() = runBlocking { + val qb = QueryBuilder() + val query = qb.note() + .leftSemijoin(qb.message()) { left, right -> left.id.eq(right.id) } + .where { c -> c.tag.eq(SqlLit.string("test")) } + val sql = query.toSql() + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + assertTrue(sql.contains("WHERE"), "Should contain WHERE: $sql") + } + + @Test + fun `rightSemijoin with where clause`() = runBlocking { + val qb = QueryBuilder() + val query = qb.note() + .rightSemijoin(qb.message()) { left, right -> left.id.eq(right.id) } + .where { c -> c.text.eq(SqlLit.string("hello")) } + val sql = query.toSql() + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + assertTrue(sql.contains("WHERE"), "Should contain WHERE: $sql") + } + + @Test + fun `IxCol eq produces IxJoinEq for join condition`() = runBlocking { + val qb = QueryBuilder() + // Verify the IxCol.eq(otherIxCol) produces the correct ON clause + val query = qb.note().leftSemijoin(qb.message()) { left, right -> + left.id.eq(right.id) + } + val sql = query.toSql() + // The ON clause should reference both table columns + assertTrue(sql.contains("\"note\".\"id\""), "Should reference note.id: $sql") + assertTrue(sql.contains("\"message\".\"id\""), "Should reference message.id: $sql") + } + + @Test + fun `FromWhere leftSemijoin chains where then join`() = runBlocking { + val qb = QueryBuilder() + val query = qb.note() + .where { c -> c.tag.eq(SqlLit.string("important")) } + .leftSemijoin(qb.message()) { left, right -> left.id.eq(right.id) } + val sql = query.toSql() + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + 
assertTrue(sql.contains("WHERE"), "Should contain WHERE from pre-join filter: $sql") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/LightModeTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/LightModeTest.kt new file mode 100644 index 00000000000..c8f2dfbb7a1 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/LightModeTest.kt @@ -0,0 +1,80 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.reducers +import module_bindings.withModuleBindings +import kotlin.test.Test +import kotlin.test.assertEquals + +/** + * Verifies that light mode connections work correctly. + * Light mode skips sending initial subscription rows — the client + * can still call reducers and receive subsequent table updates. 
+ */ +class LightModeTest { + + private suspend fun connectLightMode(): ConnectedClient { + val identityDeferred = CompletableDeferred>() + + val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withLightMode(true) + .withModuleBindings() + .onConnect { _, identity, tok -> + identityDeferred.complete(identity to tok) + } + .onConnectError { _, e -> + identityDeferred.completeExceptionally(e) + } + .build() + + val (identity, tok) = withTimeout(DEFAULT_TIMEOUT_MS) { identityDeferred.await() } + return ConnectedClient(conn = conn, identity = identity, token = tok) + } + + @Test + fun `connect in light mode and call reducer`() = runBlocking { + val client = connectLightMode() + + // Subscribe — in light mode, initial rows are skipped + val applied = CompletableDeferred() + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .subscribe(listOf("SELECT * FROM message")) + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + // Send a message and verify we receive the insert callback + val text = "light-mode-${System.nanoTime()}" + val received = CompletableDeferred() + client.conn.db.message.onInsert { _, row -> + if (row.text == text) received.complete(row.text) + } + client.conn.reducers.sendMessage(text) + + assertEquals(text, withTimeout(DEFAULT_TIMEOUT_MS) { received.await() }) + client.conn.disconnect() + } + + @Test + fun `light mode subscription starts with empty cache`() = runBlocking { + val client = connectLightMode() + + val applied = CompletableDeferred() + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .subscribe(listOf("SELECT * FROM note")) + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + // In light mode, the cache should be empty after subscription + // (no initial rows sent by server) + assertEquals(client.conn.db.note.count(), 0, "Light mode should not receive initial rows") + client.conn.disconnect() + } 
+} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/LoggerTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/LoggerTest.kt new file mode 100644 index 00000000000..8ba91a09077 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/LoggerTest.kt @@ -0,0 +1,134 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.LogLevel +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Logger +import kotlin.test.AfterTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class LoggerTest { + + private val originalLevel = Logger.level + private val originalHandler = Logger.handler + + @AfterTest + fun restore() { + Logger.level = originalLevel + Logger.handler = originalHandler + } + + @Test + fun `level can be get and set`() { + Logger.level = LogLevel.DEBUG + assertEquals(LogLevel.DEBUG, Logger.level) + + Logger.level = LogLevel.ERROR + assertEquals(LogLevel.ERROR, Logger.level) + } + + @Test + fun `custom handler receives log messages`() { + val logs = mutableListOf>() + Logger.handler = { level, message -> logs.add(level to message) } + Logger.level = LogLevel.TRACE + + Logger.info { "test-info-message" } + Logger.warn { "test-warn-message" } + Logger.debug { "test-debug-message" } + + assertTrue(logs.any { it.first == LogLevel.INFO && it.second.contains("test-info-message") }) + assertTrue(logs.any { it.first == LogLevel.WARN && it.second.contains("test-warn-message") }) + assertTrue(logs.any { it.first == LogLevel.DEBUG && it.second.contains("test-debug-message") }) + } + + @Test + fun `level filters messages below threshold`() { + val logs = mutableListOf>() + Logger.handler = { level, message -> logs.add(level to message) } + Logger.level = LogLevel.WARN + + 
Logger.info { "should-be-filtered" } + Logger.debug { "should-be-filtered" } + Logger.trace { "should-be-filtered" } + Logger.warn { "should-appear" } + Logger.error { "should-appear" } + + assertEquals(2, logs.size, "Only WARN and ERROR should pass, got: $logs") + assertTrue(logs.all { it.first == LogLevel.WARN || it.first == LogLevel.ERROR }) + } + + @Test + fun `trace messages pass at TRACE level`() { + val logs = mutableListOf>() + Logger.handler = { level, message -> logs.add(level to message) } + Logger.level = LogLevel.TRACE + + Logger.trace { "trace-message" } + assertTrue(logs.any { it.first == LogLevel.TRACE && it.second.contains("trace-message") }) + } + + @Test + fun `exception with Throwable logs stack trace`() { + val logs = mutableListOf>() + Logger.handler = { level, message -> logs.add(level to message) } + Logger.level = LogLevel.TRACE + + val ex = RuntimeException("test-exception-message") + Logger.exception(ex) + + assertTrue(logs.any { it.first == LogLevel.EXCEPTION }, "Should log at EXCEPTION level") + assertTrue( + logs.any { it.second.contains("test-exception-message") }, + "Should contain exception message in stack trace" + ) + } + + @Test + fun `exception with lambda logs message`() { + val logs = mutableListOf>() + Logger.handler = { level, message -> logs.add(level to message) } + Logger.level = LogLevel.TRACE + + Logger.exception { "exception-lambda-message" } + + assertTrue(logs.any { it.first == LogLevel.EXCEPTION && it.second.contains("exception-lambda-message") }) + } + + @Test + fun `sensitive data is redacted`() { + val logs = mutableListOf>() + Logger.handler = { level, message -> logs.add(level to message) } + Logger.level = LogLevel.TRACE + + Logger.info { "token=my-secret-token-123" } + + val message = logs.first().second + assertTrue(message.contains("[REDACTED]"), "Token value should be redacted: $message") + assertTrue(!message.contains("my-secret-token-123"), "Actual token should not appear: $message") + } + + @Test + fun 
`sensitive data redaction covers multiple patterns`() {
+        val logs = mutableListOf<Pair<LogLevel, String>>()
+        Logger.handler = { level, message -> logs.add(level to message) }
+        Logger.level = LogLevel.TRACE
+
+        Logger.info { "password=hunter2 secret=abc123" }
+
+        val message = logs.first().second
+        assertTrue(!message.contains("hunter2"), "Password should be redacted: $message")
+        assertTrue(!message.contains("abc123"), "Secret should be redacted: $message")
+    }
+
+    @Test
+    fun `lazy message is not evaluated when level is filtered`() {
+        Logger.handler = { _, _ -> }
+        Logger.level = LogLevel.ERROR
+
+        var evaluated = false
+        Logger.debug { evaluated = true; "should-not-evaluate" }
+
+        assertTrue(!evaluated, "Debug message lambda should not be evaluated at ERROR level")
+    }
+}
diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/MultiClientTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/MultiClientTest.kt
new file mode 100644
index 00000000000..01baf8788c6
--- /dev/null
+++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/MultiClientTest.kt
@@ -0,0 +1,354 @@
+package com.clockworklabs.spacetimedb_kotlin_sdk.integration
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.withTimeout
+import module_bindings.db
+import module_bindings.reducers
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFalse
+import kotlin.test.assertNotEquals
+import kotlin.test.assertNotNull
+import kotlin.test.assertIs
+import kotlin.test.assertTrue
+
+class MultiClientTest {
+
+    private suspend fun connectTwo(): Pair<ConnectedClient, ConnectedClient> {
+        val a = connectToDb().subscribeAll()
+        val b = connectToDb().subscribeAll()
+        return a to b
+    }
+
+    private suspend fun cleanupBoth(a: ConnectedClient, b: ConnectedClient) {
+        a.cleanup()
+        b.cleanup()
+    }
+
+    // ── Message propagation ──
+
+    @Test
+    fun `client B sees message sent by client A via onInsert`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-msg-${System.nanoTime()}"
+        val seen = CompletableDeferred<module_bindings.Message>()
+        b.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tag) {
+                seen.complete(msg)
+            }
+        }
+
+        a.conn.reducers.sendMessage(tag)
+        val msg = withTimeout(DEFAULT_TIMEOUT_MS) { seen.await() }
+
+        assertEquals(tag, msg.text)
+        assertEquals(a.identity, msg.sender, "Sender should be client A's identity")
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B cache contains message after client A sends it`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-cache-${System.nanoTime()}"
+        val inserted = CompletableDeferred<ULong>()
+        b.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tag) {
+                inserted.complete(msg.id)
+            }
+        }
+
+        a.conn.reducers.sendMessage(tag)
+        val msgId = withTimeout(DEFAULT_TIMEOUT_MS) { inserted.await() }
+
+        val cached = b.conn.db.message.id.find(msgId)
+        assertNotNull(cached, "Client B cache should contain the message")
+        assertEquals(tag, cached.text)
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B sees message deleted by client A via onDelete`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-del-${System.nanoTime()}"
+
+        // A sends a message, wait for B to see it
+        val insertSeen = CompletableDeferred<ULong>()
+        b.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tag) {
+                insertSeen.complete(msg.id)
+            }
+        }
+        a.conn.reducers.sendMessage(tag)
+        val msgId = withTimeout(DEFAULT_TIMEOUT_MS) { insertSeen.await() }
+
+        // B listens for deletion, A deletes
+        val deleteSeen = CompletableDeferred<ULong>()
+        b.conn.db.message.onDelete { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.id == msgId) {
+                deleteSeen.complete(msg.id)
+            }
+        }
+        a.conn.reducers.deleteMessage(msgId)
+
+        val deletedId = withTimeout(DEFAULT_TIMEOUT_MS) { deleteSeen.await() }
+        assertEquals(msgId, deletedId)
+        assertEquals(null, b.conn.db.message.id.find(msgId), "Message should be gone from B's cache")
+
+        cleanupBoth(a, b)
+    }
+
+    // ── User table propagation ──
+
+    @Test
+    fun `client B sees client A set name via onUpdate`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val newName = "multi-name-${System.nanoTime()}"
+        val updateSeen = CompletableDeferred<Pair<module_bindings.User, module_bindings.User>>()
+        b.conn.db.user.onUpdate { _, old, new ->
+            if (new.identity == a.identity && new.name == newName) {
+                updateSeen.complete(old to new)
+            }
+        }
+
+        a.conn.reducers.setName(newName)
+        val (old, new) = withTimeout(DEFAULT_TIMEOUT_MS) { updateSeen.await() }
+
+        assertNotEquals(newName, old.name, "Old name should differ from the new name")
+        assertEquals(newName, new.name)
+        assertEquals(a.identity, new.identity)
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B sees client A come online via user table`() = runBlocking {
+        val b = connectToDb().subscribeAll()
+
+        // B listens for a new user insert
+        val userSeen = CompletableDeferred<module_bindings.User>()
+        b.conn.db.user.onInsert { ctx, user ->
+            if (ctx !is EventContext.SubscribeApplied && user.online) {
+                userSeen.complete(user)
+            }
+        }
+
+        val a = connectToDb().subscribeAll()
+
+        val newUser = withTimeout(DEFAULT_TIMEOUT_MS) { userSeen.await() }
+        assertEquals(a.identity, newUser.identity)
+        assertTrue(newUser.online)
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B sees client A go offline via onUpdate`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val offlineSeen = CompletableDeferred<module_bindings.User>()
+        b.conn.db.user.onUpdate { _, old, new ->
+            if (new.identity == a.identity && old.online && !new.online) {
+                offlineSeen.complete(new)
+            }
+        }
+
+        a.conn.disconnect()
+
+        val offlineUser = withTimeout(DEFAULT_TIMEOUT_MS) { offlineSeen.await() }
+        assertEquals(a.identity, offlineUser.identity)
+        assertFalse(offlineUser.online)
+
+        // Only cleanup B (A already disconnected)
+        b.cleanup()
+    }
+
+    // ── Note propagation ──
+
+    @Test
+    fun `client B sees note added by client A`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-note-${System.nanoTime()}"
+        val noteSeen = CompletableDeferred<module_bindings.Note>()
+        b.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.tag == tag) {
+                noteSeen.complete(note)
+            }
+        }
+
+        a.conn.reducers.addNote("content from A", tag)
+        val note = withTimeout(DEFAULT_TIMEOUT_MS) { noteSeen.await() }
+
+        assertEquals(a.identity, note.owner)
+        assertEquals("content from A", note.content)
+        assertEquals(tag, note.tag)
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B sees note deleted by client A`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-notedel-${System.nanoTime()}"
+
+        // A adds note, B waits for it
+        val insertSeen = CompletableDeferred<ULong>()
+        b.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.tag == tag) {
+                insertSeen.complete(note.id)
+            }
+        }
+        a.conn.reducers.addNote("to-delete", tag)
+        val noteId = withTimeout(DEFAULT_TIMEOUT_MS) { insertSeen.await() }
+
+        // B listens for deletion, A deletes
+        val deleteSeen = CompletableDeferred<ULong>()
+        b.conn.db.note.onDelete { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.id == noteId) {
+                deleteSeen.complete(note.id)
+            }
+        }
+        a.conn.reducers.deleteNote(noteId)
+
+        val deletedId = withTimeout(DEFAULT_TIMEOUT_MS) { deleteSeen.await() }
+        assertEquals(noteId, deletedId)
+
+        cleanupBoth(a, b)
+    }
+
+    // ── EventContext cross-client ──
+
+    @Test
+    fun `client A onInsert context is Reducer for own call`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-ctx-own-${System.nanoTime()}"
+        val ctxSeen = CompletableDeferred<EventContext>()
+        a.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tag) {
+                ctxSeen.complete(ctx)
+            }
+        }
+
+        a.conn.reducers.sendMessage(tag)
+        val ctx = withTimeout(DEFAULT_TIMEOUT_MS) { ctxSeen.await() }
+        assertIs<EventContext.Reducer<*>>(ctx, "Own reducer should produce Reducer context, got: ${ctx::class.simpleName}")
+        assertEquals(a.identity, ctx.callerIdentity)
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B onInsert context is Transaction for other client's call`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tag = "multi-ctx-other-${System.nanoTime()}"
+        val ctxSeen = CompletableDeferred<EventContext>()
+        b.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tag) {
+                ctxSeen.complete(ctx)
+            }
+        }
+
+        a.conn.reducers.sendMessage(tag)
+        val ctx = withTimeout(DEFAULT_TIMEOUT_MS) { ctxSeen.await() }
+        assertTrue(
+            ctx is EventContext.Transaction,
+            "Cross-client reducer should produce Transaction context, got: ${ctx::class.simpleName}"
+        )
+
+        cleanupBoth(a, b)
+    }
+
+    // ── Concurrent operations ──
+
+    @Test
+    fun `both clients send messages and both see all messages`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val tagA = "multi-both-a-${System.nanoTime()}"
+        val tagB = "multi-both-b-${System.nanoTime()}"
+
+        // A waits to see B's message, B waits to see A's message
+        val aSeesB = CompletableDeferred<module_bindings.Message>()
+        val bSeesA = CompletableDeferred<module_bindings.Message>()
+
+        a.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tagB) {
+                aSeesB.complete(msg)
+            }
+        }
+        b.conn.db.message.onInsert { ctx, msg ->
+            if (ctx !is EventContext.SubscribeApplied && msg.text == tagA) {
+                bSeesA.complete(msg)
+            }
+        }
+
+        // Both send simultaneously
+        a.conn.reducers.sendMessage(tagA)
+        b.conn.reducers.sendMessage(tagB)
+
+        val msgFromB = withTimeout(DEFAULT_TIMEOUT_MS) { aSeesB.await() }
+        val msgFromA = withTimeout(DEFAULT_TIMEOUT_MS) { bSeesA.await() }
+
+        assertEquals(tagB, msgFromB.text)
+        assertEquals(b.identity, msgFromB.sender)
+        assertEquals(tagA, msgFromA.text)
+        assertEquals(a.identity, msgFromA.sender)
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B count updates after client A inserts`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val beforeCount = b.conn.db.note.count()
+
+        val tag = "multi-count-${System.nanoTime()}"
+        val insertSeen = CompletableDeferred<Unit>()
+        b.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.tag == tag) {
+                insertSeen.complete(Unit)
+            }
+        }
+
+        a.conn.reducers.addNote("count-test", tag)
+        withTimeout(DEFAULT_TIMEOUT_MS) { insertSeen.await() }
+
+        assertEquals(beforeCount + 1, b.conn.db.note.count(), "B's cache count should increment")
+
+        cleanupBoth(a, b)
+    }
+
+    // ── Identity isolation ──
+
+    @Test
+    fun `two anonymous clients have different identities`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        assertNotEquals(a.identity, b.identity, "Two anonymous clients should have different identities")
+
+        cleanupBoth(a, b)
+    }
+
+    @Test
+    fun `client B can look up client A by identity in user table`() = runBlocking {
+        val (a, b) = connectTwo()
+
+        val userA = b.conn.db.user.identity.find(a.identity)
+        assertNotNull(userA, "Client B should find client A in user table")
+        assertTrue(userA.online, "Client A should be online")
+
+        cleanupBoth(a, b)
+    }
+}
diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/OneOffQueryTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/OneOffQueryTest.kt
new file mode 100644
index 00000000000..7268e098a46
--- /dev/null
+++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/OneOffQueryTest.kt
@@ -0,0 +1,118 @@
+package com.clockworklabs.spacetimedb_kotlin_sdk.integration
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.OneOffQueryData
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.OneOffQueryResult
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.QueryError
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SdkResult
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.getOrNull
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.withTimeout
+import kotlin.test.Test
+import kotlin.test.assertIs
+import kotlin.test.assertTrue
+
+class OneOffQueryTest {
+
+    @Test
+    fun `callback oneOffQuery with valid SQL returns Success`() = runBlocking {
+        val client = connectToDb()
+
+        val result = CompletableDeferred<OneOffQueryResult>()
+        client.conn.oneOffQuery("SELECT * FROM user") { msg ->
+            result.complete(msg)
+        }
+
+        val qr = withTimeout(DEFAULT_TIMEOUT_MS) { result.await() }
+        assertIs<SdkResult.Success<OneOffQueryData>>(qr, "Valid SQL should return Success, got: $qr")
+
+        client.conn.disconnect()
+    }
+
+    @Test
+    fun `callback oneOffQuery with invalid SQL returns Failure`() = runBlocking {
+        val client = connectToDb()
+
+        val result = CompletableDeferred<OneOffQueryResult>()
+        client.conn.oneOffQuery("THIS IS NOT VALID SQL AT ALL") { msg ->
+            result.complete(msg)
+        }
+
+        val qr = withTimeout(DEFAULT_TIMEOUT_MS) { result.await() }
+        assertIs<SdkResult.Failure<QueryError>>(qr, "Invalid SQL should return Failure, got: $qr")
+        val serverError = assertIs<QueryError.ServerError>(qr.error, "Error should be QueryError.ServerError")
+        assertTrue(serverError.message.isNotEmpty(), "Error message should be non-empty")
+
+        client.conn.disconnect()
+    }
+
+    @Test
+    fun `suspend oneOffQuery with valid SQL returns Success`() = runBlocking {
+        val client = connectToDb()
+
+        val qr = withTimeout(DEFAULT_TIMEOUT_MS) {
+            client.conn.oneOffQuery("SELECT * FROM user")
+        }
+        assertIs<SdkResult.Success<OneOffQueryData>>(qr, "Valid SQL should return Success, got: $qr")
+        assertTrue(qr.getOrNull()!!.tableCount >= 0, "tableCount should be non-negative")
+
+        client.conn.disconnect()
+    }
+
+    @Test
+    fun `suspend oneOffQuery with invalid SQL returns Failure`() = runBlocking {
+        val client = connectToDb()
+
+        val qr = withTimeout(DEFAULT_TIMEOUT_MS) {
+            client.conn.oneOffQuery("INVALID SQL QUERY")
+        }
+        assertIs<SdkResult.Failure<QueryError>>(qr, "Invalid SQL should return Failure, got: $qr")
+
+        client.conn.disconnect()
+    }
+
+    @Test
+    fun `oneOffQuery returns Success with tableCount for populated table`() = runBlocking {
+        val client = connectToDb()
+
+        val qr = withTimeout(DEFAULT_TIMEOUT_MS) {
+            client.conn.oneOffQuery("SELECT * FROM user")
+        }
+        assertIs<SdkResult.Success<OneOffQueryData>>(qr, "Should return Success")
+        assertTrue(qr.getOrNull()!!.tableCount > 0, "Should have at least 1 table in result")
+
+        client.conn.disconnect()
+    }
+
+    @Test
+    fun `oneOffQuery returns Success for nonexistent filter`() = runBlocking {
+        val client = connectToDb()
+
+        val qr = withTimeout(DEFAULT_TIMEOUT_MS) {
+            client.conn.oneOffQuery("SELECT * FROM note WHERE tag = 'nonexistent-tag-xyz-12345'")
+        }
+        assertIs<SdkResult.Success<OneOffQueryData>>(qr, "Valid SQL should return Success even with 0 rows")
+
+        client.conn.disconnect()
+    }
+
+    @Test
+    fun `multiple concurrent oneOffQueries all return`() = runBlocking {
+        val client = connectToDb()
+
+        val results = (1..5).map { _ ->
+            val deferred = CompletableDeferred<OneOffQueryResult>()
+            client.conn.oneOffQuery("SELECT * FROM user") { msg ->
+                deferred.complete(msg)
+            }
+            deferred
+        }
+
+        results.forEachIndexed { i, deferred ->
+            val qr = withTimeout(DEFAULT_TIMEOUT_MS) { deferred.await() }
+            assertIs<SdkResult.Success<OneOffQueryData>>(qr, "Query $i should return Success, got: $qr")
+        }
+
+        client.conn.disconnect()
+    }
+}
diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/QueryBuilderEdgeCaseTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/QueryBuilderEdgeCaseTest.kt
new file mode 100644
index 00000000000..6567f6c9765
--- /dev/null
+++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/QueryBuilderEdgeCaseTest.kt
@@ -0,0 +1,265 @@
+package com.clockworklabs.spacetimedb_kotlin_sdk.integration
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SqlLit
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity
+import module_bindings.QueryBuilder +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Query builder SQL generation edge cases not already in + * TypeSafeQueryTest, ColComparisonTest, JoinTest. + */ +class QueryBuilderEdgeCaseTest { + + // --- NOT expression --- + + @Test + fun `NOT wraps expression in parentheses`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.id.gt(SqlLit.ulong(18UL)).not() + }.toSql() + assertTrue(sql.contains("NOT"), "Should contain NOT: $sql") + assertTrue(sql.contains("(NOT"), "NOT should be parenthesized: $sql") + } + + // --- NOT with AND --- + + @Test + fun `NOT combined with AND`() { + val qb = QueryBuilder() + val sql = qb.user().where { c -> + c.online.eq(SqlLit.bool(true)).not() + .and(c.name.eq(SqlLit.string("admin"))) + }.toSql() + assertTrue(sql.contains("NOT"), "Should contain NOT: $sql") + assertTrue(sql.contains("AND"), "Should contain AND: $sql") + } + + // --- Method-style .and() / .or() chaining --- + + @Test + fun `method-style and chaining`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.id.gt(SqlLit.ulong(20UL)) + .and(c.id.lt(SqlLit.ulong(30UL))) + }.toSql() + assertTrue(sql.contains("> 20"), "Should contain > 20: $sql") + assertTrue(sql.contains("AND"), "Should contain AND: $sql") + assertTrue(sql.contains("< 30"), "Should contain < 30: $sql") + } + + @Test + fun `method-style or chaining`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.tag.eq(SqlLit.string("work")) + .or(c.tag.eq(SqlLit.string("personal"))) + }.toSql() + assertTrue(sql.contains("OR"), "Should contain OR: $sql") + assertTrue(sql.contains("'work'"), "Should contain 'work': $sql") + assertTrue(sql.contains("'personal'"), "Should contain 'personal': $sql") + } + + @Test + fun `nested and-or-not produces correct structure`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.tag.eq(SqlLit.string("a")) + 
.and(c.content.eq(SqlLit.string("b")).or(c.content.eq(SqlLit.string("c")))) + }.toSql() + assertTrue(sql.contains("AND"), "Should contain AND: $sql") + assertTrue(sql.contains("OR"), "Should contain OR: $sql") + } + + // --- String escaping in WHERE --- + + @Test + fun `string with single quotes is escaped in WHERE`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.content.eq(SqlLit.string("O'Reilly")) + }.toSql() + assertTrue(sql.contains("O''Reilly"), "Single quote should be escaped: $sql") + } + + @Test + fun `string with multiple single quotes`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.content.eq(SqlLit.string("it's Bob's")) + }.toSql() + assertTrue(sql.contains("it''s Bob''s"), "All single quotes escaped: $sql") + } + + // --- Bool formatting --- + + @Test + fun `bool true formats as TRUE`() { + val qb = QueryBuilder() + val sql = qb.user().where { c -> + c.online.eq(SqlLit.bool(true)) + }.toSql() + assertTrue(sql.contains("TRUE"), "Should contain TRUE: $sql") + } + + @Test + fun `bool false formats as FALSE`() { + val qb = QueryBuilder() + val sql = qb.user().where { c -> + c.online.eq(SqlLit.bool(false)) + }.toSql() + assertTrue(sql.contains("FALSE"), "Should contain FALSE: $sql") + } + + // --- Identity hex literal in WHERE --- + + @Test + fun `Identity formats as hex literal in WHERE`() { + val id = Identity.fromHexString("ab".repeat(32)) + val qb = QueryBuilder() + val sql = qb.user().where { c -> + c.identity.eq(SqlLit.identity(id)) + }.toSql() + assertTrue(sql.contains("0x"), "Identity should be hex literal: $sql") + assertTrue(sql.contains("ab".repeat(32)), "Should contain hex value: $sql") + } + + // --- IxCol eq/neq formatting --- + + @Test + fun `IxCol eq generates correct SQL`() { + val qb = QueryBuilder() + val sql = qb.user().where { c -> + c.identity.eq(SqlLit.identity(Identity.zero())) + }.toSql() + assertTrue(sql.contains("\"identity\""), "Should reference identity column: $sql") + 
assertTrue(sql.contains("="), "Should contain = operator: $sql") + } + + @Test + fun `IxCol neq generates correct SQL`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.id.neq(SqlLit.ulong(0UL)) + }.toSql() + assertTrue(sql.contains("<>"), "Should contain <> operator: $sql") + } + + // --- Table scan (no WHERE) produces SELECT * FROM table --- + + @Test + fun `table scan without where produces simple SELECT`() { + val qb = QueryBuilder() + val sql = qb.user().toSql() + assertEquals("SELECT * FROM \"user\"", sql) + } + + @Test + fun `different tables produce different SQL`() { + val qb = QueryBuilder() + val userSql = qb.user().toSql() + val noteSql = qb.note().toSql() + val messageSql = qb.message().toSql() + assertTrue(userSql.contains("\"user\""), "Should contain user table: $userSql") + assertTrue(noteSql.contains("\"note\""), "Should contain note table: $noteSql") + assertTrue(messageSql.contains("\"message\""), "Should contain message table: $messageSql") + } + + // --- Column name quoting --- + // Note: we can't create columns with quotes in our schema, but we can verify + // that existing column names are properly quoted + + @Test + fun `column names are double-quoted in SQL`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.tag.eq(SqlLit.string("test")) + }.toSql() + assertTrue(sql.contains("\"tag\""), "Column should be double-quoted: $sql") + assertTrue(sql.contains("\"note\""), "Table should be double-quoted: $sql") + } + + @Test + fun `WHERE has table-qualified column names`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.content.eq(SqlLit.string("x")) + }.toSql() + assertTrue(sql.contains("\"note\".\"content\""), "Column should be table-qualified: $sql") + } + + // --- Semijoin with WHERE on both sides --- + + @Test + fun `left semijoin with where on left table`() { + val qb = QueryBuilder() + val sql = qb.note() + .where { c -> c.tag.eq(SqlLit.string("important")) } + 
.leftSemijoin(qb.message()) { left, right -> left.id.eq(right.id) } + .toSql() + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + assertTrue(sql.contains("\"note\".*"), "Should select note.*: $sql") + assertTrue(sql.contains("WHERE"), "Should contain WHERE: $sql") + assertTrue(sql.contains("'important'"), "Should contain left where value: $sql") + } + + @Test + fun `right semijoin selects right table columns`() { + val qb = QueryBuilder() + val sql = qb.note() + .rightSemijoin(qb.message()) { left, right -> left.id.eq(right.id) } + .toSql() + assertTrue(sql.contains("\"message\".*"), "Right semijoin should select message.*: $sql") + assertTrue(sql.contains("JOIN"), "Should contain JOIN: $sql") + } + + @Test + fun `left semijoin selects left table columns`() { + val qb = QueryBuilder() + val sql = qb.note() + .leftSemijoin(qb.message()) { left, right -> left.id.eq(right.id) } + .toSql() + assertTrue(sql.contains("\"note\".*"), "Left semijoin should select note.*: $sql") + } + + // --- Integer formatting --- + + @Test + fun `integer values format without locale separators`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.id.gt(SqlLit.ulong(1000000UL)) + }.toSql() + assertTrue(sql.contains("1000000"), "Should not have locale separators: $sql") + assertTrue(!sql.contains("1,000,000"), "Should not have commas: $sql") + } + + // --- Empty string literal --- + + @Test + fun `empty string literal in WHERE`() { + val qb = QueryBuilder() + val sql = qb.note().where { c -> + c.tag.eq(SqlLit.string("")) + }.toSql() + assertTrue(sql.contains("''"), "Should contain empty string literal: $sql") + } + + // --- Chained where with filter alias --- + + @Test + fun `where then filter then where all chain with AND`() { + val qb = QueryBuilder() + val sql = qb.note() + .where { c -> c.tag.eq(SqlLit.string("a")) } + .filter { c -> c.content.eq(SqlLit.string("b")) } + .where { c -> c.id.gt(SqlLit.ulong(0UL)) } + .toSql() + val andCount = 
Regex("AND").findAll(sql).count() + assertTrue(andCount >= 2, "Should have at least 2 ANDs: $sql") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ReducerCallbackOrderTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ReducerCallbackOrderTest.kt new file mode 100644 index 00000000000..862ce629e27 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ReducerCallbackOrderTest.kt @@ -0,0 +1,271 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Status +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Reducer/row callback interaction tests. 
+ */
+class ReducerCallbackOrderTest {
+
+    // --- Row callbacks fire during reducer event ---
+
+    @Test
+    fun `onInsert fires during reducer callback`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        val events = mutableListOf<String>()
+        val done = CompletableDeferred<Unit>()
+
+        client.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.owner == client.identity && note.tag == "order-test") {
+                events.add("onInsert")
+            }
+        }
+
+        client.conn.reducers.onAddNote { ctx, _, _ ->
+            if (ctx.callerIdentity == client.identity) {
+                events.add("onReducer")
+                done.complete(Unit)
+            }
+        }
+
+        client.conn.reducers.addNote("order-test-content", "order-test")
+        withTimeout(DEFAULT_TIMEOUT_MS) { done.await() }
+
+        assertTrue(events.contains("onInsert"), "onInsert should have fired: $events")
+        assertTrue(events.contains("onReducer"), "onReducer should have fired: $events")
+        // Both should fire in the same transaction update
+        assertEquals(2, events.size, "Should have exactly 2 events: $events")
+    }
+
+    // --- Failed reducer produces Status.Failed ---
+
+    @Test
+    fun `failed reducer has Status Failed`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        val status = CompletableDeferred<Status>()
+        client.conn.reducers.onAddNote { ctx, _, _ ->
+            if (ctx.callerIdentity == client.identity) {
+                status.complete(ctx.status)
+            }
+        }
+
+        // Empty content triggers validation error
+        client.conn.reducers.addNote("", "fail-test")
+        val result = withTimeout(DEFAULT_TIMEOUT_MS) { status.await() }
+        assertTrue(result is Status.Failed, "Empty content should fail: $result")
+    }
+
+    @Test
+    fun `failed reducer does not fire onInsert`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        var insertFired = false
+        client.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.owner == client.identity && note.tag == "no-insert-test") {
+                insertFired = true
+            }
+        }
+
+        val reducerDone = CompletableDeferred<Unit>()
+        client.conn.reducers.onAddNote { ctx, _, _ ->
+            if (ctx.callerIdentity == client.identity) {
+                reducerDone.complete(Unit)
+            }
+        }
+
+        // Empty content → validation error → no row inserted
+        client.conn.reducers.addNote("", "no-insert-test")
+        withTimeout(DEFAULT_TIMEOUT_MS) { reducerDone.await() }
+        kotlinx.coroutines.delay(200)
+
+        assertTrue(!insertFired, "onInsert should NOT fire for failed reducer")
+        client.cleanup()
+    }
+
+    @Test
+    fun `failed reducer error message is available`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        val errorMsg = CompletableDeferred<String>()
+        client.conn.reducers.onSendMessage { ctx, _ ->
+            if (ctx.callerIdentity == client.identity) {
+                val s = ctx.status
+                if (s is Status.Failed) {
+                    errorMsg.complete(s.message)
+                }
+            }
+        }
+
+        // Empty message triggers validation error
+        client.conn.reducers.sendMessage("")
+        val msg = withTimeout(DEFAULT_TIMEOUT_MS) { errorMsg.await() }
+        assertTrue(msg.contains("must not be empty"), "Error message should explain: $msg")
+
+        client.cleanup()
+    }
+
+    // --- onUpdate fires for modified row ---
+
+    @Test
+    fun `onUpdate fires when row is modified`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        // Set initial name
+        val nameDone1 = CompletableDeferred<Unit>()
+        client.conn.reducers.onSetName { ctx, _ ->
+            if (ctx.callerIdentity == client.identity && ctx.status is Status.Committed) {
+                nameDone1.complete(Unit)
+            }
+        }
+        val uniqueName1 = "update-test-${System.nanoTime()}"
+        client.conn.reducers.setName(uniqueName1)
+        withTimeout(DEFAULT_TIMEOUT_MS) { nameDone1.await() }
+
+        // Register onUpdate, then change name again
+        val updateDone = CompletableDeferred<Pair<String, String>>()
+        client.conn.db.user.onUpdate { ctx, oldRow, newRow ->
+            if (ctx !is EventContext.SubscribeApplied
+                && newRow.identity == client.identity
+                && oldRow.name == uniqueName1) {
+                updateDone.complete(oldRow.name to newRow.name)
+            }
+        }
+
+        val uniqueName2 = "update-test2-${System.nanoTime()}"
+        client.conn.reducers.setName(uniqueName2)
+        val (oldName, newName) = withTimeout(DEFAULT_TIMEOUT_MS) { updateDone.await() }
+
+        assertEquals(uniqueName1, oldName, "Old name should be first name")
+        assertEquals(uniqueName2, newName, "New name should be second name")
+
+        client.cleanup()
+    }
+
+    // --- Reducer callerIdentity matches connection ---
+
+    @Test
+    fun `reducer context has correct callerIdentity`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        val callerIdentity = CompletableDeferred<Identity>()
+        client.conn.reducers.onAddNote { ctx, _, _ ->
+            if (ctx.callerIdentity == client.identity) {
+                callerIdentity.complete(ctx.callerIdentity)
+            }
+        }
+
+        client.conn.reducers.addNote("identity-check", "id-test")
+        val identity = withTimeout(DEFAULT_TIMEOUT_MS) { callerIdentity.await() }
+        assertEquals(client.identity, identity)
+
+        client.cleanup()
+    }
+
+    @Test
+    fun `reducer context has reducerName`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        val name = CompletableDeferred<String>()
+        client.conn.reducers.onAddNote { ctx, _, _ ->
+            if (ctx.callerIdentity == client.identity) {
+                name.complete(ctx.reducerName)
+            }
+        }
+
+        client.conn.reducers.addNote("name-check", "rn-test")
+        val reducerName = withTimeout(DEFAULT_TIMEOUT_MS) { name.await() }
+        assertEquals("add_note", reducerName)
+
+        client.cleanup()
+    }
+
+    @Test
+    fun `reducer context has args matching what was sent`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        val argsContent = CompletableDeferred<String>()
+        val argsTag = CompletableDeferred<String>()
+        client.conn.reducers.onAddNote { ctx, content, tag ->
+            if (ctx.callerIdentity == client.identity) {
+                argsContent.complete(content)
+                argsTag.complete(tag)
+            }
+        }
+
+        client.conn.reducers.addNote("specific-content-xyz", "specific-tag-abc")
+        assertEquals("specific-content-xyz", withTimeout(DEFAULT_TIMEOUT_MS) { argsContent.await() })
+        assertEquals("specific-tag-abc", withTimeout(DEFAULT_TIMEOUT_MS) { argsTag.await() })
+
+        client.cleanup()
+    }
+
+    // --- Multi-client: one client's reducer is observed by another ---
+
+    @Test
+    fun `client B observes client A reducer via onInsert`() = runBlocking {
+        val clientA = connectToDb()
+        val clientB = connectToDb()
+        clientA.subscribeAll()
+        clientB.subscribeAll()
+
+        val tag = "multi-client-${System.nanoTime()}"
+
+        val bSawInsert = CompletableDeferred<Boolean>()
+        clientB.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied && note.tag == tag) {
+                bSawInsert.complete(true)
+            }
+        }
+
+        clientA.conn.reducers.addNote("hello from A", tag)
+        val result = withTimeout(DEFAULT_TIMEOUT_MS) { bSawInsert.await() }
+        assertTrue(result, "Client B should see client A's insert")
+
+        clientA.cleanup()
+        clientB.cleanup()
+    }
+
+    @Test
+    fun `client B observes client A name change via onUpdate`() = runBlocking {
+        val clientA = connectToDb()
+        val clientB = connectToDb()
+        clientA.subscribeAll()
+        clientB.subscribeAll()
+
+        val uniqueName = "multi-update-${System.nanoTime()}"
+
+        val bSawUpdate = CompletableDeferred<String>()
+        clientB.conn.db.user.onUpdate { ctx, _, newRow ->
+            if (ctx !is EventContext.SubscribeApplied && newRow.name == uniqueName) {
+                bSawUpdate.complete(newRow.name)
+            }
+        }
+
+        clientA.conn.reducers.setName(uniqueName)
+        val name = withTimeout(DEFAULT_TIMEOUT_MS) { bSawUpdate.await() }
+        assertEquals(uniqueName, name)
+
+        clientA.cleanup()
+        clientB.cleanup()
+    }
+}
diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/RemoveCallbacksTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/RemoveCallbacksTest.kt
new file mode 100644
index 00000000000..fc97dc77334
--- /dev/null
+++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/RemoveCallbacksTest.kt
@@ -0,0 +1,74 @@
+package com.clockworklabs.spacetimedb_kotlin_sdk.integration
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext
+import kotlinx.coroutines.CompletableDeferred
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.withTimeout
+import module_bindings.User
+import module_bindings.db
+import module_bindings.reducers
+import kotlin.test.Test
+import kotlin.test.assertTrue
+
+class RemoveCallbacksTest {
+
+    @Test
+    fun `removeOnUpdate prevents callback from firing`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        var callbackFired = false
+        val cb: (EventContext, User, User) -> Unit = { _, _, _ -> callbackFired = true }
+
+        client.conn.db.user.onUpdate(cb)
+        client.conn.db.user.removeOnUpdate(cb)
+
+        // Trigger an update by setting name
+        val done = CompletableDeferred<Unit>()
+        client.conn.reducers.onSetName { ctx, _ ->
+            if (ctx.callerIdentity == client.identity) done.complete(Unit)
+        }
+        client.conn.reducers.setName("removeOnUpdate-test-${System.nanoTime()}")
+        withTimeout(DEFAULT_TIMEOUT_MS) { done.await() }
+
+        kotlinx.coroutines.delay(200)
+        assertTrue(!callbackFired, "Removed onUpdate callback should not fire")
+
+        client.cleanup()
+    }
+
+    @Test
+    fun `removeOnBeforeDelete prevents callback from firing`() = runBlocking {
+        val client = connectToDb()
+        client.subscribeAll()
+
+        var callbackFired = false
+        val cb: (EventContext, module_bindings.Note) -> Unit = { _, _ -> callbackFired = true }
+
+        client.conn.db.note.onBeforeDelete(cb)
+        client.conn.db.note.removeOnBeforeDelete(cb)
+
+        // Insert then delete a note
+        val insertDone = CompletableDeferred<ULong>()
+        client.conn.db.note.onInsert { ctx, note ->
+            if (ctx !is EventContext.SubscribeApplied
+                && note.owner == client.identity && note.tag == "rm-bd-test") {
+                insertDone.complete(note.id)
+            }
+        }
+        client.conn.reducers.addNote("removeOnBeforeDelete-test", "rm-bd-test")
+        val noteId = withTimeout(DEFAULT_TIMEOUT_MS) { insertDone.await() }
+
+        val delDone = CompletableDeferred<Unit>()
+        client.conn.reducers.onDeleteNote { ctx, _ ->
+            if (ctx.callerIdentity == client.identity) delDone.complete(Unit)
+        }
+        client.conn.reducers.deleteNote(noteId)
+        withTimeout(DEFAULT_TIMEOUT_MS) { delDone.await() }
+
+        kotlinx.coroutines.delay(200)
+        assertTrue(!callbackFired, "Removed onBeforeDelete callback should not fire")
+
+        client.cleanup()
+    }
+}
diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ScheduleAtTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ScheduleAtTest.kt
new file mode 100644
index 00000000000..53b47e69119
--- /dev/null
+++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/ScheduleAtTest.kt
@@ -0,0 +1,91 @@
+package com.clockworklabs.spacetimedb_kotlin_sdk.integration
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ScheduleAt
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import kotlin.time.Duration.Companion.minutes
+import kotlin.time.Duration.Companion.seconds
+import kotlin.time.Instant
+
+class ScheduleAtTest {
+
+    // --- interval factory ---
+
+    @Test
+    fun `interval creates Interval variant`() {
+        val schedule = ScheduleAt.interval(5.seconds)
+        assertTrue(schedule is ScheduleAt.Interval, "Should be Interval, got: ${schedule::class.simpleName}")
+    }
+
+    @Test
+    fun `interval preserves duration`() {
+        val schedule = ScheduleAt.interval(5.seconds) as ScheduleAt.Interval
+        assertEquals(5000L, schedule.duration.millis)
+    }
+
+    @Test
+    fun `interval with minutes`() {
+        val schedule = ScheduleAt.interval(2.minutes) as ScheduleAt.Interval
+        assertEquals(120_000L, schedule.duration.millis)
+    }
+
+    // --- time factory ---
+
+    @Test
+ fun `time creates Time variant`() { + val instant = Instant.fromEpochMilliseconds(System.currentTimeMillis()) + val schedule = ScheduleAt.time(instant) + assertTrue(schedule is ScheduleAt.Time, "Should be Time, got: ${schedule::class.simpleName}") + } + + @Test + fun `time preserves instant`() { + val millis = System.currentTimeMillis() + val instant = Instant.fromEpochMilliseconds(millis) + val schedule = ScheduleAt.time(instant) as ScheduleAt.Time + assertEquals(millis, schedule.timestamp.millisSinceUnixEpoch) + } + + // --- Direct constructors --- + + @Test + fun `Interval constructor with TimeDuration`() { + val dur = TimeDuration.fromMillis(3000L) + val schedule = ScheduleAt.Interval(dur) + assertEquals(3000L, schedule.duration.millis) + } + + @Test + fun `Time constructor with Timestamp`() { + val ts = Timestamp.fromMillis(42000L) + val schedule = ScheduleAt.Time(ts) + assertEquals(42000L, schedule.timestamp.millisSinceUnixEpoch) + } + + // --- Equality --- + + @Test + fun `Interval equality`() { + val a = ScheduleAt.interval(5.seconds) + val b = ScheduleAt.interval(5.seconds) + assertEquals(a, b) + } + + @Test + fun `Time equality`() { + val instant = Instant.fromEpochMilliseconds(1000L) + val a = ScheduleAt.time(instant) + val b = ScheduleAt.time(instant) + assertEquals(a, b) + } + + @Test + fun `Interval and Time are not equal`() { + val interval = ScheduleAt.interval(1.seconds) + val time = ScheduleAt.time(Instant.fromEpochMilliseconds(1000L)) + assertTrue(interval != time, "Interval and Time should not be equal") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SpacetimeTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SpacetimeTest.kt new file mode 100644 index 00000000000..1a4cec98ea8 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SpacetimeTest.kt @@ 
-0,0 +1,98 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.reducers +import module_bindings.withModuleBindings +import java.net.Socket + +val HOST: String = System.getenv("SPACETIMEDB_HOST") ?: "ws://localhost:3000" +val DB_NAME: String = System.getenv("SPACETIMEDB_DB_NAME") ?: "chat-all" +const val DEFAULT_TIMEOUT_MS = 10_000L + +private fun checkServerReachable() { + val url = java.net.URI(HOST.replace("ws://", "http://").replace("wss://", "https://")) + val host = url.host ?: "localhost" + val port = if (url.port > 0) url.port else 3000 + try { + Socket().use { it.connect(java.net.InetSocketAddress(host, port), 2000) } + } catch (_: Exception) { + throw AssertionError( + "SpacetimeDB server is not reachable at $host:$port. " + + "Start it with: spacetimedb-cli start" + ) + } +} + +fun createTestHttpClient(): HttpClient = HttpClient(OkHttp) { + install(WebSockets) +} + +data class ConnectedClient( + val conn: DbConnection, + val identity: Identity, + val token: String, +) + +suspend fun connectToDb(token: String? 
= null): ConnectedClient { + checkServerReachable() + val identityDeferred = CompletableDeferred>() + + val connection = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withToken(token) + .withModuleBindings() + .onConnect { _, identity, tok -> + identityDeferred.complete(identity to tok) + } + .onConnectError { _, e -> + identityDeferred.completeExceptionally(e) + } + .build() + + val (identity, tok) = withTimeout(DEFAULT_TIMEOUT_MS) { identityDeferred.await() } + return ConnectedClient(conn = connection, identity = identity, token = tok) +} + +suspend fun ConnectedClient.subscribeAll(): ConnectedClient { + val applied = CompletableDeferred() + conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe(listOf( + "SELECT * FROM user", + "SELECT * FROM message", + "SELECT * FROM note", + "SELECT * FROM reminder", + "SELECT * FROM big_int_row", + )) + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + return this +} + +suspend fun ConnectedClient.cleanup() { + for (msg in conn.db.message.all()) { + if (msg.sender == identity) { + conn.reducers.deleteMessage(msg.id) + } + } + for (note in conn.db.note.all()) { + if (note.owner == identity) { + conn.reducers.deleteNote(note.id) + } + } + for (reminder in conn.db.reminder.all()) { + if (reminder.owner == identity) { + conn.reducers.cancelReminder(reminder.scheduledId) + } + } + conn.disconnect() +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SpacetimeUuidTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SpacetimeUuidTest.kt new file mode 100644 index 00000000000..35e03f887ed --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SpacetimeUuidTest.kt @@ -0,0 
+1,153 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Counter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.SpacetimeUuid +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.UuidVersion +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue +import kotlin.time.Instant + +class SpacetimeUuidTest { + + @Test + fun `NIL uuid has version Nil`() { + assertEquals(UuidVersion.Nil, SpacetimeUuid.NIL.getVersion()) + } + + @Test + fun `MAX uuid has version Max`() { + assertEquals(UuidVersion.Max, SpacetimeUuid.MAX.getVersion()) + } + + @Test + fun `random produces V4 uuid`() { + val uuid = SpacetimeUuid.random() + assertEquals(UuidVersion.V4, uuid.getVersion()) + } + + @Test + fun `random produces unique values`() { + val a = SpacetimeUuid.random() + val b = SpacetimeUuid.random() + assertNotEquals(a, b) + } + + @Test + fun `parse roundtrips through toString`() { + val uuid = SpacetimeUuid.random() + val str = uuid.toString() + val parsed = SpacetimeUuid.parse(str) + assertEquals(uuid, parsed) + } + + @Test + fun `parse invalid string throws`() { + assertFailsWith { + SpacetimeUuid.parse("not-a-uuid") + } + } + + @Test + fun `toHexString returns 32 lowercase hex chars`() { + val uuid = SpacetimeUuid.random() + val hex = uuid.toHexString() + assertEquals(32, hex.length, "Hex string should be 32 chars: $hex") + assertTrue(hex.all { it in '0'..'9' || it in 'a'..'f' }, "Should be lowercase hex: $hex") + } + + @Test + fun `toByteArray returns 16 bytes`() { + val uuid = SpacetimeUuid.random() + assertEquals(16, uuid.toByteArray().size) + } + + @Test + fun `NIL and MAX are distinct`() { + assertNotEquals(SpacetimeUuid.NIL, SpacetimeUuid.MAX) + } + + @Test + fun `compareTo 
orders NIL before MAX`() { + assertTrue(SpacetimeUuid.NIL < SpacetimeUuid.MAX) + } + + @Test + fun `compareTo is reflexive`() { + val uuid = SpacetimeUuid.random() + assertEquals(0, uuid.compareTo(uuid)) + } + + @Test + fun `fromRandomBytesV4 produces V4 uuid`() { + val bytes = ByteArray(16) { it.toByte() } + val uuid = SpacetimeUuid.fromRandomBytesV4(bytes) + assertEquals(UuidVersion.V4, uuid.getVersion()) + } + + @Test + fun `fromRandomBytesV4 rejects wrong size`() { + assertFailsWith { + SpacetimeUuid.fromRandomBytesV4(ByteArray(8)) + } + } + + @Test + fun `fromCounterV7 produces V7 uuid`() { + val counter = Counter(0) + val now = Timestamp(Instant.fromEpochMilliseconds(System.currentTimeMillis())) + val randomBytes = ByteArray(4) { 0x42 } + val uuid = SpacetimeUuid.fromCounterV7(counter, now, randomBytes) + assertEquals(UuidVersion.V7, uuid.getVersion()) + } + + @Test + fun `fromCounterV7 increments counter`() { + val counter = Counter(0) + val now = Timestamp(Instant.fromEpochMilliseconds(System.currentTimeMillis())) + val randomBytes = ByteArray(4) { 0x42 } + + val a = SpacetimeUuid.fromCounterV7(counter, now, randomBytes) + val b = SpacetimeUuid.fromCounterV7(counter, now, randomBytes) + assertNotEquals(a, b, "Sequential V7 UUIDs should differ due to counter") + assertTrue(a.getCounter() < b.getCounter(), "Counter should increment") + } + + @Test + fun `fromCounterV7 rejects too few random bytes`() { + val counter = Counter(0) + val now = Timestamp(Instant.fromEpochMilliseconds(System.currentTimeMillis())) + assertFailsWith { + SpacetimeUuid.fromCounterV7(counter, now, ByteArray(2)) + } + } + + @Test + fun `getCounter returns embedded counter value`() { + val counter = Counter(42) + val now = Timestamp(Instant.fromEpochMilliseconds(System.currentTimeMillis())) + val uuid = SpacetimeUuid.fromCounterV7(counter, now, ByteArray(4) { 0 }) + assertEquals(42, uuid.getCounter(), "getCounter should return the embedded counter") + } + + @Test + fun `equals and 
hashCode are consistent`() { + val uuid = SpacetimeUuid.random() + val same = SpacetimeUuid.parse(uuid.toString()) + assertEquals(uuid, same) + assertEquals(uuid.hashCode(), same.hashCode()) + } + + @Test + fun `NIL toHexString is all zeros`() { + assertEquals("00000000000000000000000000000000", SpacetimeUuid.NIL.toHexString()) + } + + @Test + fun `MAX toHexString is all f`() { + assertEquals("ffffffffffffffffffffffffffffffff", SpacetimeUuid.MAX.toHexString()) + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SqlFormatTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SqlFormatTest.kt new file mode 100644 index 00000000000..11c40408fd8 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SqlFormatTest.kt @@ -0,0 +1,121 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SqlFormat +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class SqlFormatTest { + + // --- quoteIdent --- + + @Test + fun `quoteIdent wraps in double quotes`() { + assertEquals("\"tableName\"", SqlFormat.quoteIdent("tableName")) + } + + @Test + fun `quoteIdent escapes internal double quotes`() { + assertEquals("\"bad\"\"name\"", SqlFormat.quoteIdent("bad\"name")) + } + + @Test + fun `quoteIdent handles empty string`() { + assertEquals("\"\"", SqlFormat.quoteIdent("")) + } + + @Test + fun `quoteIdent with multiple double quotes`() { + assertEquals("\"a\"\"b\"\"c\"", SqlFormat.quoteIdent("a\"b\"c")) + } + + @Test + fun `quoteIdent preserves spaces`() { + assertEquals("\"my table\"", SqlFormat.quoteIdent("my table")) + } + + // --- formatStringLiteral --- + + @Test + fun `formatStringLiteral wraps in single quotes`() { + assertEquals("'hello'", SqlFormat.formatStringLiteral("hello")) + } 
+ + @Test + fun `formatStringLiteral escapes single quotes`() { + assertEquals("'O''Brien'", SqlFormat.formatStringLiteral("O'Brien")) + } + + @Test + fun `formatStringLiteral empty string`() { + assertEquals("''", SqlFormat.formatStringLiteral("")) + } + + @Test + fun `formatStringLiteral multiple single quotes`() { + assertEquals("'it''s a ''test'''", SqlFormat.formatStringLiteral("it's a 'test'")) + } + + @Test + fun `formatStringLiteral preserves double quotes`() { + assertEquals("'say \"hi\"'", SqlFormat.formatStringLiteral("say \"hi\"")) + } + + @Test + fun `formatStringLiteral preserves special chars`() { + assertEquals("'tab\tnewline\n'", SqlFormat.formatStringLiteral("tab\tnewline\n")) + } + + // --- formatHexLiteral --- + + @Test + fun `formatHexLiteral adds 0x prefix`() { + assertEquals("0x01020304", SqlFormat.formatHexLiteral("01020304")) + } + + @Test + fun `formatHexLiteral strips existing 0x prefix`() { + assertEquals("0xabcdef", SqlFormat.formatHexLiteral("0xabcdef")) + } + + @Test + fun `formatHexLiteral strips 0X prefix case insensitive`() { + assertEquals("0xABCDEF", SqlFormat.formatHexLiteral("0XABCDEF")) + } + + @Test + fun `formatHexLiteral strips hyphens`() { + assertEquals("0x0123456789ab", SqlFormat.formatHexLiteral("01234567-89ab")) + } + + @Test + fun `formatHexLiteral accepts uppercase hex`() { + assertEquals("0xABCD", SqlFormat.formatHexLiteral("ABCD")) + } + + @Test + fun `formatHexLiteral accepts mixed case hex`() { + assertEquals("0xAbCd", SqlFormat.formatHexLiteral("AbCd")) + } + + @Test + fun `formatHexLiteral rejects non-hex chars`() { + assertFailsWith { + SqlFormat.formatHexLiteral("xyz123") + } + } + + @Test + fun `formatHexLiteral rejects empty after prefix strip`() { + assertFailsWith { + SqlFormat.formatHexLiteral("0x") + } + } + + @Test + fun `formatHexLiteral rejects empty string`() { + assertFailsWith { + SqlFormat.formatHexLiteral("") + } + } +} diff --git 
a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SqlLitTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SqlLitTest.kt new file mode 100644 index 00000000000..fa84774f9df --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SqlLitTest.kt @@ -0,0 +1,159 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SqlLit +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.SpacetimeUuid +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class SqlLitTest { + + // --- String --- + + @Test + fun `string literal wraps in quotes`() { + val lit = SqlLit.string("hello") + assertTrue(lit.sql.startsWith("'"), "Should start with quote: ${lit.sql}") + assertTrue(lit.sql.endsWith("'"), "Should end with quote: ${lit.sql}") + assertTrue(lit.sql.contains("hello"), "Should contain value: ${lit.sql}") + } + + @Test + fun `string literal escapes single quotes`() { + val lit = SqlLit.string("it's") + // SQL standard: single quotes are escaped by doubling them + assertTrue(lit.sql.contains("''"), "Should escape single quote: ${lit.sql}") + } + + @Test + fun `string literal handles empty string`() { + val lit = SqlLit.string("") + assertEquals("''", lit.sql, "Empty string should be two quotes") + } + + // --- Bool --- + + @Test + fun `bool true literal`() { + assertEquals("TRUE", SqlLit.bool(true).sql) + } + + @Test + fun `bool false literal`() { + assertEquals("FALSE", SqlLit.bool(false).sql) + } + + // --- Numeric types --- + + @Test + fun `byte literal`() { + assertEquals("42", SqlLit.byte(42).sql) + assertEquals("-128", 
SqlLit.byte(Byte.MIN_VALUE).sql) + assertEquals("127", SqlLit.byte(Byte.MAX_VALUE).sql) + } + + @Test + fun `ubyte literal`() { + assertEquals("0", SqlLit.ubyte(0u).sql) + assertEquals("255", SqlLit.ubyte(UByte.MAX_VALUE).sql) + } + + @Test + fun `short literal`() { + assertEquals("1000", SqlLit.short(1000).sql) + assertEquals("-32768", SqlLit.short(Short.MIN_VALUE).sql) + } + + @Test + fun `ushort literal`() { + assertEquals("0", SqlLit.ushort(0u).sql) + assertEquals("65535", SqlLit.ushort(UShort.MAX_VALUE).sql) + } + + @Test + fun `int literal`() { + assertEquals("42", SqlLit.int(42).sql) + assertEquals("0", SqlLit.int(0).sql) + assertEquals("-1", SqlLit.int(-1).sql) + } + + @Test + fun `uint literal`() { + assertEquals("0", SqlLit.uint(0u).sql) + assertEquals("4294967295", SqlLit.uint(UInt.MAX_VALUE).sql) + } + + @Test + fun `long literal`() { + assertEquals("0", SqlLit.long(0L).sql) + assertEquals("9223372036854775807", SqlLit.long(Long.MAX_VALUE).sql) + } + + @Test + fun `ulong literal`() { + assertEquals("0", SqlLit.ulong(0uL).sql) + assertEquals("18446744073709551615", SqlLit.ulong(ULong.MAX_VALUE).sql) + } + + @Test + fun `float literal`() { + val lit = SqlLit.float(3.14f) + assertTrue(lit.sql.startsWith("3.14"), "Float should contain value: ${lit.sql}") + } + + @Test + fun `double literal`() { + assertEquals("3.14", SqlLit.double(3.14).sql) + } + + // --- Identity / ConnectionId / UUID --- + + @Test + fun `identity literal is hex formatted`() { + val identity = Identity.zero() + val lit = SqlLit.identity(identity) + assertTrue(lit.sql.isNotEmpty(), "Identity literal should not be empty: ${lit.sql}") + // Zero identity => all zeros hex + assertTrue(lit.sql.contains("0".repeat(32)), "Zero identity should contain zeros: ${lit.sql}") + } + + @Test + fun `identity literal from hex string`() { + val hex = "ab".repeat(32) // 64 hex chars for 32-byte U256 + val identity = Identity.fromHexString(hex) + val lit = SqlLit.identity(identity) + 
assertTrue(lit.sql.contains("ab"), "Should contain hex value: ${lit.sql}") + } + + @Test + fun `connectionId literal is hex formatted`() { + val connId = ConnectionId.zero() + val lit = SqlLit.connectionId(connId) + assertTrue(lit.sql.isNotEmpty(), "ConnectionId literal should not be empty: ${lit.sql}") + } + + @Test + fun `connectionId literal from random`() { + val connId = ConnectionId.random() + val lit = SqlLit.connectionId(connId) + assertTrue(lit.sql.isNotEmpty(), "Random connectionId literal should not be empty: ${lit.sql}") + } + + @Test + fun `uuid literal is hex formatted`() { + val uuid = SpacetimeUuid.NIL + val lit = SqlLit.uuid(uuid) + assertTrue(lit.sql.contains("0".repeat(32)), "NIL uuid should be all zeros: ${lit.sql}") + } + + @Test + fun `uuid literal for random uuid`() { + val uuid = SpacetimeUuid.random() + val lit = SqlLit.uuid(uuid) + assertTrue(lit.sql.isNotEmpty(), "UUID literal should not be empty") + // Hex literal format is typically 0x... or X'...' + assertTrue(lit.sql.length > 32, "UUID literal should contain hex representation: ${lit.sql}") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/StatsTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/StatsTest.kt new file mode 100644 index 00000000000..4345c057e98 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/StatsTest.kt @@ -0,0 +1,159 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class StatsTest { + + @Test + fun `stats are zero before any operations`() = 
runBlocking { + val client = connectToDb() + + assertEquals(0, client.conn.stats.reducerRequestTracker.sampleCount, "reducer samples should be 0 initially") + assertEquals(0, client.conn.stats.oneOffRequestTracker.sampleCount, "oneOff samples should be 0 initially") + assertEquals(0, client.conn.stats.reducerRequestTracker.requestsAwaitingResponse, "no in-flight requests initially") + assertNull(client.conn.stats.reducerRequestTracker.allTimeMinMax, "allTimeMinMax should be null initially") + + client.conn.disconnect() + } + + @Test + fun `subscriptionRequestTracker increments after subscribe`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val subSamples = client.conn.stats.subscriptionRequestTracker.sampleCount + assertTrue(subSamples > 0, "subscriptionRequestTracker should have samples after subscribe, got $subSamples") + + client.conn.disconnect() + } + + @Test + fun `reducerRequestTracker increments after reducer call`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val before = client.conn.stats.reducerRequestTracker.sampleCount + + val reducerDone = CompletableDeferred() + client.conn.reducers.onSendMessage { ctx, _ -> + if (ctx.callerIdentity == client.identity) reducerDone.complete(Unit) + } + client.conn.reducers.sendMessage("stats-reducer-${System.nanoTime()}") + withTimeout(DEFAULT_TIMEOUT_MS) { reducerDone.await() } + + val after = client.conn.stats.reducerRequestTracker.sampleCount + assertTrue(after > before, "reducerRequestTracker should increment, before=$before after=$after") + + client.cleanup() + } + + @Test + fun `oneOffRequestTracker increments after suspend query`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val before = client.conn.stats.oneOffRequestTracker.sampleCount + + client.conn.oneOffQuery("SELECT * FROM user") + + val after = client.conn.stats.oneOffRequestTracker.sampleCount + assertTrue(after > before, "oneOffRequestTracker should increment, 
before=$before after=$after") + + client.conn.disconnect() + } + + @Test + fun `allTimeMinMax is set after reducer call`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val reducerDone = CompletableDeferred() + client.conn.reducers.onSendMessage { ctx, _ -> + if (ctx.callerIdentity == client.identity) reducerDone.complete(Unit) + } + client.conn.reducers.sendMessage("stats-minmax-${System.nanoTime()}") + withTimeout(DEFAULT_TIMEOUT_MS) { reducerDone.await() } + + val minMax = assertNotNull(client.conn.stats.reducerRequestTracker.allTimeMinMax, "allTimeMinMax should be set") + assertTrue( + minMax.min.duration >= kotlin.time.Duration.ZERO, + "min duration should be non-negative" + ) + + client.cleanup() + } + + @Test + fun `minMaxTimes returns null when no window has rotated`() = runBlocking { + val client = connectToDb() + + // On a fresh tracker, no window has rotated yet + val minMax = client.conn.stats.reducerRequestTracker.minMaxTimes(60) + assertNull(minMax, "minMaxTimes should return null before any window rotation") + + client.conn.disconnect() + } + + @Test + fun `procedureRequestTracker exists and starts empty`() = runBlocking { + val client = connectToDb() + + val tracker = client.conn.stats.procedureRequestTracker + assertEquals(0, tracker.sampleCount, "No procedures called, sample count should be 0") + assertNull(tracker.allTimeMinMax, "No procedures called, allTimeMinMax should be null") + assertEquals(0, tracker.requestsAwaitingResponse, "No procedures in flight") + + client.conn.disconnect() + } + + @Test + fun `applyMessageTracker exists`() = runBlocking { + val client = connectToDb() + + val tracker = client.conn.stats.applyMessageTracker + // After connecting, there may or may not be apply messages depending on timing + assertTrue(tracker.sampleCount >= 0, "Sample count should be non-negative") + assertTrue(tracker.requestsAwaitingResponse >= 0, "Awaiting should be non-negative") + + client.conn.disconnect() + } + + 
@Test + fun `applyMessageTracker records after subscription`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val tracker = client.conn.stats.applyMessageTracker + // After subscribing, server applies the subscription which should register + assertTrue(tracker.sampleCount >= 0, "Sample count should be non-negative after subscribe") + + client.conn.disconnect() + } + + @Test + fun `all five trackers are distinct objects`() = runBlocking { + val client = connectToDb() + + val stats = client.conn.stats + val trackers = listOf( + stats.reducerRequestTracker, + stats.subscriptionRequestTracker, + stats.oneOffRequestTracker, + stats.procedureRequestTracker, + stats.applyMessageTracker, + ) + // All should be distinct instances + val unique = trackers.toSet() + assertEquals(5, unique.size, "All 5 trackers should be distinct objects") + + client.conn.disconnect() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SubscriptionBuilderTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SubscriptionBuilderTest.kt new file mode 100644 index 00000000000..eb98b3c9695 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/SubscriptionBuilderTest.kt @@ -0,0 +1,256 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionError +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionState +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.subscribeToAllTables +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class SubscriptionBuilderTest { 
+ + @Test + fun `addQuery with subscribe builds multi-query subscription`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .addQuery("SELECT * FROM user") + .addQuery("SELECT * FROM message") + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + val users = client.conn.db.user.all() + assertTrue(users.isNotEmpty(), "Should see at least our own user after subscribe") + + client.conn.disconnect() + } + + @Test + fun `subscribeToAllTables subscribes to every table`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribeToAllTables() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + val users = client.conn.db.user.all() + assertTrue(users.isNotEmpty(), "Should see at least our own user after subscribeToAllTables") + + client.conn.disconnect() + } + + @Test + fun `subscribe with no queries throws`() = runBlocking { + val client = connectToDb() + + assertFailsWith { + client.conn.subscriptionBuilder() + .onApplied { _ -> } + .subscribe() + } + + client.conn.disconnect() + } + + @Test + fun `onError fires on invalid SQL`() = runBlocking { + val client = connectToDb() + val error = CompletableDeferred() + + client.conn.subscriptionBuilder() + .onApplied { _ -> error.completeExceptionally(AssertionError("Should not apply invalid SQL")) } + .onError { _, err -> error.complete(err) } + .subscribe("THIS IS NOT VALID SQL") + + val err = withTimeout(DEFAULT_TIMEOUT_MS) { error.await() } + assertTrue(err is SubscriptionError.ServerError, "Should be ServerError") + assertTrue(err.message.isNotEmpty(), "Error message should be non-empty: ${err.message}") 
+ + client.conn.disconnect() + } + + @Test + fun `multiple onApplied callbacks all fire`() = runBlocking { + val client = connectToDb() + val first = CompletableDeferred() + val second = CompletableDeferred() + + client.conn.subscriptionBuilder() + .onApplied { _ -> first.complete(Unit) } + .onApplied { _ -> second.complete(Unit) } + .onError { _, err -> first.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { first.await() } + withTimeout(DEFAULT_TIMEOUT_MS) { second.await() } + + client.conn.disconnect() + } + + @Test + fun `subscription handle state transitions from pending to active`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + // Immediately after subscribe(), handle should be pending + // (may already be active if server responds fast, so check both) + assertTrue( + handle.state == SubscriptionState.PENDING || handle.state == SubscriptionState.ACTIVE, + "State should be PENDING or ACTIVE immediately after subscribe, got: ${handle.state}" + ) + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + assertEquals(SubscriptionState.ACTIVE, handle.state, "State should be ACTIVE after onApplied") + assertTrue(handle.isActive, "isActive should be true") + assertFalse(handle.isPending, "isPending should be false") + + client.conn.disconnect() + } + + @Test + fun `unsubscribeThen transitions handle to ended`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM note") + + withTimeout(DEFAULT_TIMEOUT_MS) { 
applied.await() } + assertTrue(handle.isActive) + + val unsubDone = CompletableDeferred() + handle.unsubscribeThen { _ -> unsubDone.complete(Unit) } + withTimeout(DEFAULT_TIMEOUT_MS) { unsubDone.await() } + + assertEquals(SubscriptionState.ENDED, handle.state, "State should be ENDED after unsubscribe") + assertFalse(handle.isActive, "isActive should be false after unsubscribe") + } + + @Test + fun `queries contains the subscribed query`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + assertEquals(1, handle.queries.size, "Should have 1 query") + assertEquals("SELECT * FROM user", handle.queries[0]) + + client.conn.disconnect() + } + + @Test + fun `queries contains multiple subscribed queries`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .addQuery("SELECT * FROM user") + .addQuery("SELECT * FROM note") + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + assertEquals(2, handle.queries.size, "Should have 2 queries") + assertTrue(handle.queries.contains("SELECT * FROM user")) + assertTrue(handle.queries.contains("SELECT * FROM note")) + + client.conn.disconnect() + } + + @Test + fun `isUnsubscribing is false while active`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + 
withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + assertFalse(handle.isUnsubscribing, "Should not be unsubscribing while active") + + client.conn.disconnect() + } + + @Test + fun `isEnded is false while active`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + assertFalse(handle.isEnded, "Should not be ended while active") + + client.conn.disconnect() + } + + @Test + fun `isEnded is true after unsubscribeThen completes`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM note") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + val unsubDone = CompletableDeferred() + handle.unsubscribeThen { _ -> unsubDone.complete(Unit) } + withTimeout(DEFAULT_TIMEOUT_MS) { unsubDone.await() } + + assertTrue(handle.isEnded, "Should be ended after unsubscribe") + assertEquals(SubscriptionState.ENDED, handle.state) + } + + @Test + fun `querySetId is assigned`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + val id = handle.querySetId + assertTrue(id.id >= 0u, "querySetId should be non-negative: ${id.id}") + + client.conn.disconnect() + } +} diff --git 
a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TableCacheTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TableCacheTest.kt new file mode 100644 index 00000000000..a0a1200cbfe --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TableCacheTest.kt @@ -0,0 +1,227 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.db +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class TableCacheTest { + + @Test + fun `count returns number of cached rows`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val count = client.conn.db.user.count() + assertTrue(count > 0, "Should have at least 1 user (ourselves), got $count") + + client.conn.disconnect() + } + + @Test + fun `count is zero before subscribe`() = runBlocking { + val client = connectToDb() + + // Before subscribing, cache should be empty + assertEquals(0, client.conn.db.note.count(), "count should be 0 before subscribe") + assertTrue(client.conn.db.note.all().isEmpty(), "all() should be empty before subscribe") + assertFalse(client.conn.db.note.iter().any(), "iter() should have no elements before subscribe") + + client.conn.disconnect() + } + + @Test + fun `count updates after insert`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val before = client.conn.db.note.count() + + val insertDone = CompletableDeferred() + client.conn.db.note.onInsert { ctx, note -> + if (ctx !is EventContext.SubscribeApplied + && note.owner == 
client.identity && note.tag == "count-test") { + insertDone.complete(Unit) + } + } + client.conn.reducers.addNote("count-test-content", "count-test") + withTimeout(DEFAULT_TIMEOUT_MS) { insertDone.await() } + + val after = client.conn.db.note.count() + assertEquals(before + 1, after, "count should increment by 1 after insert") + + client.cleanup() + } + + @Test + fun `iter iterates over cached rows`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val first = client.conn.db.user.iter().firstOrNull() + assertNotNull(first, "iter() should have at least one element") + + client.conn.disconnect() + } + + @Test + fun `all returns list of cached rows`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val all = client.conn.db.user.all() + assertTrue(all.isNotEmpty(), "all() should return non-empty list") + assertEquals(client.conn.db.user.count(), all.size, "all().size should match count()") + + client.conn.disconnect() + } + + @Test + fun `all and count are consistent with iter`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + val all = client.conn.db.user.all() + val count = client.conn.db.user.count() + val iterCount = client.conn.db.user.iter().count() + + assertEquals(count, all.size, "all().size should match count()") + assertEquals(count, iterCount, "iter count should match count()") + + client.conn.disconnect() + } + + @Test + fun `UniqueIndex find returns row by key`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + // Look up our own user by identity (UniqueIndex) + val user = client.conn.db.user.identity.find(client.identity) + assertNotNull(user, "Should find our own user by identity") + assertTrue(user.online, "Our user should be online") + + client.conn.disconnect() + } + + @Test + fun `UniqueIndex find returns null for missing key`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + // Note.id UniqueIndex — look up non-existent id + 
val note = client.conn.db.note.id.find(ULong.MAX_VALUE) + assertEquals(null, note, "Should return null for non-existent key") + + client.conn.disconnect() + } + + @Test + fun `removeOnInsert prevents callback from firing`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + var callbackFired = false + val cb: (EventContext, module_bindings.Note) -> Unit = + { _, _ -> callbackFired = true } + + client.conn.db.note.onInsert(cb) + client.conn.db.note.removeOnInsert(cb) + + // Insert a note — the removed callback should NOT fire + val done = CompletableDeferred() + client.conn.reducers.onAddNote { ctx, _, _ -> + if (ctx.callerIdentity == client.identity) done.complete(Unit) + } + client.conn.reducers.addNote("remove-insert-test", "test") + withTimeout(DEFAULT_TIMEOUT_MS) { done.await() } + + // Small delay to ensure callback would have fired if registered + kotlinx.coroutines.delay(200) + assertTrue(!callbackFired, "Removed onInsert callback should not fire") + + client.cleanup() + } + + @Test + fun `removeOnDelete prevents callback from firing`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + var callbackFired = false + val cb: (EventContext, module_bindings.Note) -> Unit = + { _, _ -> callbackFired = true } + + client.conn.db.note.onDelete(cb) + client.conn.db.note.removeOnDelete(cb) + + // Insert then delete a note + val insertDone = CompletableDeferred() + client.conn.db.note.onInsert { ctx, note -> + if (ctx !is EventContext.SubscribeApplied + && note.owner == client.identity && note.tag == "rm-del-test") { + insertDone.complete(note.id) + } + } + client.conn.reducers.addNote("remove-delete-test", "rm-del-test") + val noteId = withTimeout(DEFAULT_TIMEOUT_MS) { insertDone.await() } + + val delDone = CompletableDeferred() + client.conn.reducers.onDeleteNote { ctx, _ -> + if (ctx.callerIdentity == client.identity) delDone.complete(Unit) + } + client.conn.reducers.deleteNote(noteId) + withTimeout(DEFAULT_TIMEOUT_MS) { 
delDone.await() } + + kotlinx.coroutines.delay(200) + assertTrue(!callbackFired, "Removed onDelete callback should not fire") + + client.cleanup() + } + + @Test + fun `onBeforeDelete fires before row is removed from cache`() = runBlocking { + val client = connectToDb() + client.subscribeAll() + + // Insert a note first + val insertDone = CompletableDeferred() + client.conn.db.note.onInsert { ctx, note -> + if (ctx !is EventContext.SubscribeApplied + && note.owner == client.identity && note.tag == "before-del-test") { + insertDone.complete(note.id) + } + } + client.conn.reducers.addNote("before-delete-test", "before-del-test") + val noteId = withTimeout(DEFAULT_TIMEOUT_MS) { insertDone.await() } + + // Register onBeforeDelete — row should still be in cache when this fires + val beforeDeleteFired = CompletableDeferred() + client.conn.db.note.onBeforeDelete { _, note -> + if (note.id == noteId) { + // Check if the row is still findable in cache + val stillInCache = client.conn.db.note.id.find(noteId) != null + beforeDeleteFired.complete(stillInCache) + } + } + + val delDone = CompletableDeferred() + client.conn.reducers.onDeleteNote { ctx, _ -> + if (ctx.callerIdentity == client.identity) delDone.complete(Unit) + } + client.conn.reducers.deleteNote(noteId) + withTimeout(DEFAULT_TIMEOUT_MS) { delDone.await() } + + val wasStillInCache = withTimeout(DEFAULT_TIMEOUT_MS) { beforeDeleteFired.await() } + assertTrue(wasStillInCache, "Row should still be in cache during onBeforeDelete") + + client.cleanup() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TimeDurationTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TimeDurationTest.kt new file mode 100644 index 00000000000..4324d5093d7 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TimeDurationTest.kt @@ -0,0 +1,157 @@ +package 
com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.microseconds +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class TimeDurationTest { + + // --- Factory --- + + @Test + fun `fromMillis creates correct duration`() { + val d = TimeDuration.fromMillis(1500L) + assertEquals(1500L, d.millis) + assertEquals(1_500_000L, d.micros) + } + + @Test + fun `fromMillis zero`() { + val d = TimeDuration.fromMillis(0L) + assertEquals(0L, d.millis) + assertEquals(0L, d.micros) + } + + @Test + fun `constructor from Duration`() { + val d = TimeDuration(3.seconds) + assertEquals(3000L, d.millis) + assertEquals(3_000_000L, d.micros) + } + + @Test + fun `constructor from microseconds Duration`() { + val d = TimeDuration(500.microseconds) + assertEquals(500L, d.micros) + assertEquals(0L, d.millis) // 500us < 1ms + } + + // --- Accessors --- + + @Test + fun `micros and millis are consistent`() { + val d = TimeDuration.fromMillis(2345L) + assertEquals(d.micros, d.millis * 1000) + } + + // --- Arithmetic --- + + @Test + fun `plus adds durations`() { + val a = TimeDuration.fromMillis(100L) + val b = TimeDuration.fromMillis(200L) + val result = a + b + assertEquals(300L, result.millis) + } + + @Test + fun `minus subtracts durations`() { + val a = TimeDuration.fromMillis(500L) + val b = TimeDuration.fromMillis(200L) + val result = a - b + assertEquals(300L, result.millis) + } + + @Test + fun `minus can produce negative duration`() { + val a = TimeDuration.fromMillis(100L) + val b = TimeDuration.fromMillis(500L) + val result = a - b + assertTrue(result.micros < 0, "100 - 500 should be negative") + assertEquals(-400L, result.millis) + } + + @Test + fun `plus and minus are inverse`() { + val a = TimeDuration.fromMillis(1000L) + val b 
= TimeDuration.fromMillis(300L) + assertEquals(a, (a + b) - b) + } + + @Test + fun `plus zero is identity`() { + val a = TimeDuration.fromMillis(42L) + assertEquals(a, a + TimeDuration.fromMillis(0L)) + } + + // --- Comparison --- + + @Test + fun `compareTo orders by duration`() { + val short = TimeDuration.fromMillis(100L) + val long = TimeDuration.fromMillis(200L) + assertTrue(short < long) + assertTrue(long > short) + } + + @Test + fun `compareTo equal durations`() { + val a = TimeDuration.fromMillis(500L) + val b = TimeDuration.fromMillis(500L) + assertEquals(0, a.compareTo(b)) + } + + @Test + fun `compareTo negative vs positive`() { + val neg = TimeDuration((-100).milliseconds) + val pos = TimeDuration(100.milliseconds) + assertTrue(neg < pos) + } + + // --- Formatting --- + + @Test + fun `toString positive duration`() { + val d = TimeDuration.fromMillis(1500L) + val str = d.toString() + assertTrue(str.startsWith("+"), "Positive duration should start with +: $str") + assertTrue(str.contains("1."), "Should show 1 second: $str") + } + + @Test + fun `toString negative duration`() { + val d = TimeDuration((-1500).milliseconds) + val str = d.toString() + assertTrue(str.startsWith("-"), "Negative duration should start with -: $str") + } + + @Test + fun `toString zero`() { + val d = TimeDuration.fromMillis(0L) + val str = d.toString() + assertTrue(str.contains("0.000000"), "Zero should be +0.000000: $str") + } + + @Test + fun `toString has 6 digit microsecond precision`() { + val d = TimeDuration.fromMillis(1234L) + val str = d.toString() + // format: +1.234000 + val frac = str.substringAfter(".") + assertEquals(6, frac.length, "Fraction should be 6 digits: $str") + } + + // --- equals / hashCode --- + + @Test + fun `equal durations from different constructors`() { + val a = TimeDuration.fromMillis(1000L) + val b = TimeDuration(1.seconds) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } +} diff --git 
a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TimestampTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TimestampTest.kt new file mode 100644 index 00000000000..df3a626149e --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TimestampTest.kt @@ -0,0 +1,182 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TimestampTest { + + // --- Factories --- + + @Test + fun `UNIX_EPOCH is at epoch zero`() { + assertEquals(0L, Timestamp.UNIX_EPOCH.microsSinceUnixEpoch) + assertEquals(0L, Timestamp.UNIX_EPOCH.millisSinceUnixEpoch) + } + + @Test + fun `now returns a timestamp after epoch`() { + val now = Timestamp.now() + assertTrue(now.microsSinceUnixEpoch > 0, "now() should be after epoch") + } + + @Test + fun `now returns increasing timestamps`() { + val a = Timestamp.now() + val b = Timestamp.now() + assertTrue(b >= a, "Second now() should be >= first") + } + + @Test + fun `fromMillis creates correct timestamp`() { + val ts = Timestamp.fromMillis(1000L) + assertEquals(1000L, ts.millisSinceUnixEpoch) + assertEquals(1_000_000L, ts.microsSinceUnixEpoch) + } + + @Test + fun `fromMillis zero is epoch`() { + assertEquals(Timestamp.UNIX_EPOCH, Timestamp.fromMillis(0L)) + } + + @Test + fun `fromEpochMicroseconds creates correct timestamp`() { + val ts = Timestamp.fromEpochMicroseconds(1_500_000L) + assertEquals(1_500_000L, ts.microsSinceUnixEpoch) + assertEquals(1500L, ts.millisSinceUnixEpoch) + } + + @Test + fun `fromEpochMicroseconds zero is epoch`() { + assertEquals(Timestamp.UNIX_EPOCH, Timestamp.fromEpochMicroseconds(0L)) + } + + // 
--- Accessors --- + + @Test + fun `microsSinceUnixEpoch and millisSinceUnixEpoch are consistent`() { + val ts = Timestamp.fromMillis(12345L) + assertEquals(ts.microsSinceUnixEpoch, ts.millisSinceUnixEpoch * 1000) + } + + // --- Arithmetic --- + + @Test + fun `plus TimeDuration adds time`() { + val ts = Timestamp.fromMillis(1000L) + val dur = TimeDuration.fromMillis(500L) + val result = ts + dur + assertEquals(1500L, result.millisSinceUnixEpoch) + } + + @Test + fun `minus TimeDuration subtracts time`() { + val ts = Timestamp.fromMillis(1000L) + val dur = TimeDuration.fromMillis(300L) + val result = ts - dur + assertEquals(700L, result.millisSinceUnixEpoch) + } + + @Test + fun `minus Timestamp returns TimeDuration`() { + val a = Timestamp.fromMillis(1000L) + val b = Timestamp.fromMillis(400L) + val diff = a - b + assertEquals(600L, diff.millis) + } + + @Test + fun `minus Timestamp can be negative`() { + val a = Timestamp.fromMillis(100L) + val b = Timestamp.fromMillis(500L) + val diff = a - b + assertTrue(diff.micros < 0, "Earlier - later should be negative: ${diff.micros}") + } + + @Test + fun `since returns duration between timestamps`() { + val a = Timestamp.fromMillis(1000L) + val b = Timestamp.fromMillis(300L) + val dur = a.since(b) + assertEquals(700L, dur.millis) + } + + @Test + fun `plus and minus are inverse operations`() { + val ts = Timestamp.fromMillis(5000L) + val dur = TimeDuration.fromMillis(1234L) + assertEquals(ts, (ts + dur) - dur) + } + + // --- Comparison --- + + @Test + fun `compareTo orders by time`() { + val early = Timestamp.fromMillis(100L) + val late = Timestamp.fromMillis(200L) + assertTrue(early < late) + assertTrue(late > early) + } + + @Test + fun `compareTo equal timestamps`() { + val a = Timestamp.fromMillis(100L) + val b = Timestamp.fromMillis(100L) + assertEquals(0, a.compareTo(b)) + } + + @Test + fun `UNIX_EPOCH is less than now`() { + assertTrue(Timestamp.UNIX_EPOCH < Timestamp.now()) + } + + // --- Formatting --- + + @Test + fun 
`toISOString contains Z suffix`() { + val ts = Timestamp.fromMillis(1000L) + val iso = ts.toISOString() + assertTrue(iso.endsWith("Z"), "ISO string should end with Z: $iso") + } + + @Test + fun `toISOString contains T separator`() { + val ts = Timestamp.now() + val iso = ts.toISOString() + assertTrue(iso.contains("T"), "ISO string should contain T: $iso") + } + + @Test + fun `toISOString preserves microsecond precision`() { + val ts = Timestamp.fromEpochMicroseconds(1_000_123_456L) + val iso = ts.toISOString() + // Should have 6-digit microsecond fraction + assertTrue(iso.contains("."), "ISO string should have fractional part: $iso") + val frac = iso.substringAfter(".").removeSuffix("Z") + assertEquals(6, frac.length, "Fraction should be 6 digits: $frac") + } + + @Test + fun `toString equals toISOString`() { + val ts = Timestamp.fromMillis(42000L) + assertEquals(ts.toISOString(), ts.toString()) + } + + @Test + fun `UNIX_EPOCH toISOString is 1970-01-01`() { + val iso = Timestamp.UNIX_EPOCH.toISOString() + assertTrue(iso.startsWith("1970-01-01"), "Epoch should be 1970-01-01: $iso") + } + + // --- equals / hashCode --- + + @Test + fun `equal timestamps from different factories are equal`() { + val a = Timestamp.fromMillis(5000L) + val b = Timestamp.fromEpochMicroseconds(5_000_000L) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TokenReconnectTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TokenReconnectTest.kt new file mode 100644 index 00000000000..9d0a3ceb653 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TokenReconnectTest.kt @@ -0,0 +1,69 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import kotlinx.coroutines.runBlocking +import kotlin.test.Test +import kotlin.test.assertEquals 
+import kotlin.test.assertNotEquals + +class TokenReconnectTest { + + @Test + fun `reconnect with saved token returns same identity`() = runBlocking { + val first = connectToDb() + val savedToken = first.token + val savedIdentity = first.identity + first.conn.disconnect() + + val second = connectToDb(token = savedToken) + assertEquals(savedIdentity, second.identity, "Identity should be the same when reconnecting with saved token") + second.conn.disconnect() + } + + @Test + fun `reconnect with saved token returns same token`() = runBlocking { + val first = connectToDb() + val savedToken = first.token + first.conn.disconnect() + + val second = connectToDb(token = savedToken) + assertEquals(savedToken, second.token, "Token should be the same when reconnecting") + second.conn.disconnect() + } + + @Test + fun `connect without token generates new identity each time`() = runBlocking { + val first = connectToDb() + val firstIdentity = first.identity + first.conn.disconnect() + + val second = connectToDb() + assertNotEquals(firstIdentity, second.identity, "Different anonymous connections should have different identities") + second.conn.disconnect() + } + + @Test + fun `connect without token generates new token each time`() = runBlocking { + val first = connectToDb() + val firstToken = first.token + first.conn.disconnect() + + val second = connectToDb() + assertNotEquals(firstToken, second.token, "Different anonymous connections should have different tokens") + second.conn.disconnect() + } + + @Test + fun `token from first connection works after multiple reconnects`() = runBlocking { + val first = connectToDb() + val savedToken = first.token + val savedIdentity = first.identity + first.conn.disconnect() + + // Reconnect 3 times with same token + for (i in 1..3) { + val client = connectToDb(token = savedToken) + assertEquals(savedIdentity, client.identity, "Identity should match on reconnect #$i") + client.conn.disconnect() + } + } +} diff --git 
a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TypeSafeQueryTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TypeSafeQueryTest.kt new file mode 100644 index 00000000000..19ecbf91be9 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/TypeSafeQueryTest.kt @@ -0,0 +1,151 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SqlLit +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.QueryBuilder +import module_bindings.addQuery +import module_bindings.db +import module_bindings.reducers +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TypeSafeQueryTest { + + @Test + fun `where with eq generates correct SQL and subscribes`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + // Subscribe using type-safe query: user where online = true + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .addQuery { qb -> qb.user().where { c -> c.online.eq(SqlLit.bool(true)) } } + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + // We should see at least ourselves (we're online) + val users = client.conn.db.user.all() + assertTrue(users.isNotEmpty(), "Should see online users") + assertTrue(users.all { it.online }, "All users should be online with this filter") + + client.conn.disconnect() + } + + @Test + fun `filter is alias for where`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> 
applied.completeExceptionally(RuntimeException("$err")) } + .addQuery { qb -> qb.user().filter { c -> c.online.eq(SqlLit.bool(true)) } } + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + val users = client.conn.db.user.all() + assertTrue(users.isNotEmpty(), "Filter should work like where") + + client.conn.disconnect() + } + + @Test + fun `neq comparison works`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + // Subscribe to users where online != false (i.e. online users) + client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .addQuery { qb -> qb.user().where { c -> c.online.neq(SqlLit.bool(false)) } } + .subscribe() + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + val users = client.conn.db.user.all() + assertTrue(users.all { it.online }, "neq(false) should return only online users") + + client.conn.disconnect() + } + + @Test + fun `boolean combinators and-or-not`() = runBlocking { + val client = connectToDb() + + // First subscribe to everything so we have data + client.subscribeAll() + + // Add a note so we can test with note table + val noteDone = CompletableDeferred() + client.conn.reducers.onAddNote { ctx, _, _ -> + if (ctx.callerIdentity == client.identity) noteDone.complete(Unit) + } + client.conn.reducers.addNote("bool-test-content", "bool-test") + withTimeout(DEFAULT_TIMEOUT_MS) { noteDone.await() } + + // Test that the query DSL generates valid SQL with and/or/not + val qb = QueryBuilder() + val query = qb.note().where { c -> + c.tag.eq(SqlLit.string("bool-test")) + .and(c.content.eq(SqlLit.string("bool-test-content"))) + } + val sql = query.toSql() + assertTrue(sql.contains("AND"), "SQL should contain AND: $sql") + + val queryOr = qb.note().where { c -> + c.tag.eq(SqlLit.string("a")).or(c.tag.eq(SqlLit.string("b"))) + } + assertTrue(queryOr.toSql().contains("OR"), "SQL should 
contain OR") + + val queryNot = qb.user().where { c -> + c.online.eq(SqlLit.bool(true)).not() + } + assertTrue(queryNot.toSql().contains("NOT"), "SQL should contain NOT") + + client.cleanup() + } + + @Test + fun `SqlLit creates typed literals`() = runBlocking { + // Test various SqlLit factory methods produce valid SQL strings + assertTrue(SqlLit.string("hello").sql.contains("hello")) + assertEquals(SqlLit.bool(true).sql, "TRUE") + assertEquals(SqlLit.bool(false).sql, "FALSE") + assertEquals(SqlLit.int(42).sql, "42") + assertEquals(SqlLit.ulong(100UL).sql, "100") + assertEquals(SqlLit.long(999L).sql, "999") + assertEquals(SqlLit.float(1.5f).sql, "1.5") + assertEquals(SqlLit.double(2.5).sql, "2.5") + } + + @Test + fun `NullableCol generates valid SQL`() = runBlocking { + // User.name is a NullableCol + // Test that NullableCol methods produce correct SQL strings + val qb = QueryBuilder() + + // NullableCol.eq with SqlLiteral + val eqSql = qb.user().where { c -> c.name.eq(SqlLit.string("alice")) }.toSql() + assertTrue(eqSql.contains("\"name\"") && eqSql.contains("alice"), "eq SQL: $eqSql") + + // NullableCol.neq with SqlLiteral + val neqSql = qb.user().where { c -> c.name.neq(SqlLit.string("bob")) }.toSql() + assertTrue(neqSql.contains("<>"), "neq SQL: $neqSql") + + // NullableCol.eq with another NullableCol (self-reference — valid SQL structure) + val colEqSql = qb.user().where { c -> c.name.eq(c.name) }.toSql() + assertTrue(colEqSql.contains("\"name\" = "), "col-eq SQL: $colEqSql") + + // NullableCol comparison operators + val ltSql = qb.user().where { c -> c.name.lt(SqlLit.string("z")) }.toSql() + assertTrue(ltSql.contains("<"), "lt SQL: $ltSql") + + val gteSql = qb.user().where { c -> c.name.gte(SqlLit.string("a")) }.toSql() + assertTrue(gteSql.contains(">="), "gte SQL: $gteSql") + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/UnsubscribeFlagsTest.kt 
b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/UnsubscribeFlagsTest.kt new file mode 100644 index 00000000000..eb663bc7ebb --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/UnsubscribeFlagsTest.kt @@ -0,0 +1,145 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionState +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class UnsubscribeFlagsTest { + + @Test + fun `unsubscribeThen transitions to ENDED`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM note") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + assertTrue(handle.isActive) + + val unsubDone = CompletableDeferred() + handle.unsubscribeThen { _ -> unsubDone.complete(Unit) } + withTimeout(DEFAULT_TIMEOUT_MS) { unsubDone.await() } + + assertEquals(SubscriptionState.ENDED, handle.state) + + client.conn.disconnect() + } + + @Test + fun `unsubscribeThen callback receives context`() = runBlocking { + val client = connectToDb() + val applied = CompletableDeferred() + + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM note") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + + val gotContext = CompletableDeferred() + handle.unsubscribeThen { ctx -> + gotContext.complete(ctx) 
+ } + + val result = withTimeout(DEFAULT_TIMEOUT_MS) { gotContext.await() } + assertNotNull(result, "unsubscribeThen callback should receive non-null context") + + client.conn.disconnect() + } + + @Test + fun `unsubscribe completes without error`() = runBlocking { + val client = connectToDb() + + val applied = CompletableDeferred() + val handle = client.conn.subscriptionBuilder() + .onApplied { _ -> applied.complete(Unit) } + .onError { _, err -> applied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied.await() } + assertTrue(handle.isActive, "Should be active after applied") + + val unsubDone = CompletableDeferred() + handle.unsubscribeThen { _ -> unsubDone.complete(Unit) } + withTimeout(DEFAULT_TIMEOUT_MS) { unsubDone.await() } + + // After unsubscribeThen callback fires, the handle should be ENDED + assertTrue(handle.isEnded, "Should be ended after unsubscribe completes") + + client.conn.disconnect() + } + + @Test + fun `multiple subscriptions can be independently unsubscribed`() = runBlocking { + val client = connectToDb() + + val applied1 = CompletableDeferred() + val handle1 = client.conn.subscriptionBuilder() + .onApplied { _ -> applied1.complete(Unit) } + .onError { _, err -> applied1.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + val applied2 = CompletableDeferred() + val handle2 = client.conn.subscriptionBuilder() + .onApplied { _ -> applied2.complete(Unit) } + .onError { _, err -> applied2.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM note") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied1.await() } + withTimeout(DEFAULT_TIMEOUT_MS) { applied2.await() } + + // Unsubscribe only handle1 + val unsub1 = CompletableDeferred() + handle1.unsubscribeThen { _ -> unsub1.complete(Unit) } + withTimeout(DEFAULT_TIMEOUT_MS) { unsub1.await() } + + assertEquals(SubscriptionState.ENDED, handle1.state, "handle1 should be 
ENDED") + assertEquals(SubscriptionState.ACTIVE, handle2.state, "handle2 should still be ACTIVE") + + client.conn.disconnect() + } + + @Test + fun `unsubscribe then re-subscribe works`() = runBlocking { + val client = connectToDb() + + // Subscribe + val applied1 = CompletableDeferred() + val handle1 = client.conn.subscriptionBuilder() + .onApplied { _ -> applied1.complete(Unit) } + .onError { _, err -> applied1.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied1.await() } + + // Unsubscribe + val unsub = CompletableDeferred() + handle1.unsubscribeThen { _ -> unsub.complete(Unit) } + withTimeout(DEFAULT_TIMEOUT_MS) { unsub.await() } + assertTrue(handle1.isEnded) + + // Re-subscribe + val applied2 = CompletableDeferred() + val handle2 = client.conn.subscriptionBuilder() + .onApplied { _ -> applied2.complete(Unit) } + .onError { _, err -> applied2.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + withTimeout(DEFAULT_TIMEOUT_MS) { applied2.await() } + assertTrue(handle2.isActive, "Re-subscribed handle should be active") + assertNotEquals(handle1.querySetId, handle2.querySetId, "New subscription should get new querySetId") + + client.conn.disconnect() + } +} diff --git a/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/WithCallbackDispatcherTest.kt b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/WithCallbackDispatcherTest.kt new file mode 100644 index 00000000000..899fa7051d9 --- /dev/null +++ b/sdks/kotlin/integration-tests/src/test/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/integration/WithCallbackDispatcherTest.kt @@ -0,0 +1,115 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.integration + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import kotlinx.coroutines.CompletableDeferred +import 
kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import module_bindings.reducers +import module_bindings.withModuleBindings +import java.util.concurrent.Executors +import kotlin.test.Test +import kotlin.test.assertTrue + +class WithCallbackDispatcherTest { + + private fun createNamedDispatcher(name: String): Pair { + val executor = Executors.newSingleThreadExecutor { r -> Thread(r, name) } + return executor.asCoroutineDispatcher() to executor + } + + @Test + fun `onConnect callback runs on custom dispatcher`() = runBlocking { + val (dispatcher, executor) = createNamedDispatcher("custom-cb-thread") + + val threadName = CompletableDeferred() + val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withModuleBindings() + .withCallbackDispatcher(dispatcher) + .onConnect { _, _, _ -> threadName.complete(Thread.currentThread().name) } + .onConnectError { _, e -> threadName.completeExceptionally(e) } + .build() + + val name = withTimeout(DEFAULT_TIMEOUT_MS) { threadName.await() } + assertTrue(name.startsWith("custom-cb-thread"), "onConnect should run on custom thread, got: $name") + + conn.disconnect() + dispatcher.close() + executor.shutdown() + } + + @Test + fun `subscription onApplied callback runs on custom dispatcher`() = runBlocking { + val (dispatcher, executor) = createNamedDispatcher("sub-cb-thread") + + val connected = CompletableDeferred() + val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withModuleBindings() + .withCallbackDispatcher(dispatcher) + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { _, e -> connected.completeExceptionally(e) } + .build() + + withTimeout(DEFAULT_TIMEOUT_MS) { connected.await() } + + val threadName = CompletableDeferred() + conn.subscriptionBuilder() + .onApplied { _ -> 
threadName.complete(Thread.currentThread().name) } + .onError { _, err -> threadName.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + + val name = withTimeout(DEFAULT_TIMEOUT_MS) { threadName.await() } + assertTrue(name.startsWith("sub-cb-thread"), "onApplied should run on custom thread, got: $name") + + conn.disconnect() + dispatcher.close() + executor.shutdown() + } + + @Test + fun `reducer callback runs on custom dispatcher`() = runBlocking { + val (dispatcher, executor) = createNamedDispatcher("reducer-cb-thread") + + val connected = CompletableDeferred() + val conn = DbConnection.Builder() + .withHttpClient(createTestHttpClient()) + .withUri(HOST) + .withDatabaseName(DB_NAME) + .withModuleBindings() + .withCallbackDispatcher(dispatcher) + .onConnect { _, _, _ -> + connected.complete(Unit) + } + .onConnectError { _, e -> connected.completeExceptionally(e) } + .build() + + withTimeout(DEFAULT_TIMEOUT_MS) { connected.await() } + + // Subscribe first so reducer callbacks can fire + val subApplied = CompletableDeferred() + conn.subscriptionBuilder() + .onApplied { _ -> subApplied.complete(Unit) } + .onError { _, err -> subApplied.completeExceptionally(RuntimeException("$err")) } + .subscribe("SELECT * FROM user") + withTimeout(DEFAULT_TIMEOUT_MS) { subApplied.await() } + + val threadName = CompletableDeferred() + conn.reducers.onSetName { _, _ -> + threadName.complete(Thread.currentThread().name) + } + conn.reducers.setName("dispatcher-test-${System.nanoTime()}") + + val name = withTimeout(DEFAULT_TIMEOUT_MS) { threadName.await() } + assertTrue(name.startsWith("reducer-cb-thread"), "reducer callback should run on custom thread, got: $name") + + conn.disconnect() + dispatcher.close() + executor.shutdown() + } +} diff --git a/sdks/kotlin/settings.gradle.kts b/sdks/kotlin/settings.gradle.kts new file mode 100644 index 00000000000..056566964b4 --- /dev/null +++ b/sdks/kotlin/settings.gradle.kts @@ -0,0 +1,40 @@ 
+@file:Suppress("UnstableApiUsage") + +rootProject.name = "SpacetimedbKotlinSdk" +enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") + +pluginManagement { + includeBuild("spacetimedb-gradle-plugin") + repositories { + google { + mavenContent { + includeGroupAndSubgroups("androidx") + includeGroupAndSubgroups("com.android") + includeGroupAndSubgroups("com.google") + } + } + mavenCentral() + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + repositories { + google { + mavenContent { + includeGroupAndSubgroups("androidx") + includeGroupAndSubgroups("com.android") + includeGroupAndSubgroups("com.google") + } + } + mavenCentral() + } +} + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "1.0.0" +} + +include(":spacetimedb-sdk") +include(":integration-tests") +include(":codegen-tests") diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/README.md b/sdks/kotlin/spacetimedb-gradle-plugin/README.md new file mode 100644 index 00000000000..9e641c85b3f --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/README.md @@ -0,0 +1,72 @@ +# SpacetimeDB Gradle Plugin + +Gradle plugin for SpacetimeDB Kotlin projects. Automatically generates Kotlin client bindings and build-time configuration from your SpacetimeDB module. + +## Setup + +```kotlin +// settings.gradle.kts +pluginManagement { + includeBuild("/path/to/SpacetimeDB/sdks/kotlin/spacetimedb-gradle-plugin") +} + +// build.gradle.kts +plugins { + id("com.clockworklabs.spacetimedb") +} +``` + +## Configuration + +```kotlin +spacetimedb { + // Path to the SpacetimeDB module directory. 
+ // Default: read from "module-path" in spacetime.json, falls back to "spacetimedb/" + modulePath.set(file("server")) + + // Path to spacetimedb-cli binary (default: resolved from PATH) + cli.set(file("/path/to/spacetimedb-cli")) + + // Config file paths (default: spacetime.local.json and spacetime.json in root project) + localConfig.set(file("spacetime.local.json")) + mainConfig.set(file("spacetime.json")) +} +``` + +## Generated Files + +### Bindings (`build/generated/spacetimedb/bindings/`) + +Kotlin data classes, table handles, reducer stubs, and query builders generated from your module's schema via `spacetimedb-cli generate`. + +### SpacetimeConfig (`build/generated/spacetimedb/config/SpacetimeConfig.kt`) + +Build-time constants extracted from `spacetime.local.json` / `spacetime.json`: + +```kotlin +package module_bindings + +object SpacetimeConfig { + const val DATABASE_NAME: String = "my-app" // from "database" field + const val MODULE_PATH: String = "./spacetimedb" // from "module-path" field +} +``` + +Fields are only included when present in the config. `spacetime.local.json` takes priority over `spacetime.json`. + +## Tasks + +| Task | Description | +|------|-------------| +| `generateSpacetimeBindings` | Runs `spacetimedb-cli generate` to produce Kotlin bindings. Wired into `compileKotlin`. | +| `generateSpacetimeConfig` | Generates `SpacetimeConfig.kt` from project config. Wired into `compileKotlin`. | +| `cleanSpacetimeModule` | Deletes `spacetimedb/target/` (Rust build cache). Runs as part of `gradle clean`. | + +## Notes + +- **`gradle clean` triggers a full Rust recompilation** on the next build, since `cleanSpacetimeModule` deletes the Cargo `target/` directory. To clean only Kotlin artifacts: + ``` + gradle clean -x cleanSpacetimeModule + ``` +- The plugin detects module source changes and re-generates bindings automatically. +- Both `org.jetbrains.kotlin.jvm` and `org.jetbrains.kotlin.multiplatform` are supported. 
diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/build.gradle.kts b/sdks/kotlin/spacetimedb-gradle-plugin/build.gradle.kts new file mode 100644 index 00000000000..0ac6ec702dd --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/build.gradle.kts @@ -0,0 +1,24 @@ +plugins { + alias(libs.plugins.kotlinJvm) + `java-gradle-plugin` +} + +group = "com.clockworklabs" +version = "0.1.0" + +kotlin { + jvmToolchain(21) +} + +dependencies { + compileOnly("org.jetbrains.kotlin:kotlin-gradle-plugin:${libs.versions.kotlin.get()}") +} + +gradlePlugin { + plugins { + create("spacetimedb") { + id = "com.clockworklabs.spacetimedb" + implementationClass = "com.clockworklabs.spacetimedb.SpacetimeDbPlugin" + } + } +} diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/settings.gradle.kts b/sdks/kotlin/spacetimedb-gradle-plugin/settings.gradle.kts new file mode 100644 index 00000000000..1a4794687b2 --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/settings.gradle.kts @@ -0,0 +1,11 @@ +dependencyResolutionManagement { + repositories { + mavenCentral() + gradlePluginPortal() + } + versionCatalogs { + create("libs") { + from(files("../gradle/libs.versions.toml")) + } + } +} diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/GenerateBindingsTask.kt b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/GenerateBindingsTask.kt new file mode 100644 index 00000000000..7391bed1a9c --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/GenerateBindingsTask.kt @@ -0,0 +1,82 @@ +package com.clockworklabs.spacetimedb + +import org.gradle.api.DefaultTask +import org.gradle.api.file.ConfigurableFileCollection +import org.gradle.api.file.DirectoryProperty +import org.gradle.api.file.RegularFileProperty +import org.gradle.api.tasks.InputFile +import org.gradle.api.tasks.InputFiles +import org.gradle.api.tasks.Internal +import org.gradle.api.tasks.Optional +import 
org.gradle.api.tasks.OutputDirectory +import org.gradle.api.tasks.PathSensitive +import org.gradle.api.tasks.PathSensitivity +import org.gradle.api.tasks.TaskAction +import org.gradle.process.ExecOperations +import javax.inject.Inject + +abstract class GenerateBindingsTask @Inject constructor( + private val execOps: ExecOperations +) : DefaultTask() { + + @get:InputFile + @get:Optional + @get:PathSensitive(PathSensitivity.ABSOLUTE) + abstract val cli: RegularFileProperty + + @get:Internal + abstract val modulePath: DirectoryProperty + + @get:InputFiles + @get:PathSensitive(PathSensitivity.RELATIVE) + abstract val moduleSourceFiles: ConfigurableFileCollection + + @get:OutputDirectory + abstract val outputDir: DirectoryProperty + + init { + group = "spacetimedb" + description = "Generate SpacetimeDB Kotlin client bindings" + } + + @TaskAction + fun generate() { + val moduleDir = modulePath.get().asFile + require(moduleDir.isDirectory) { + "SpacetimeDB module directory not found at '${moduleDir.absolutePath}'. " + + "Set the correct path via: spacetimedb { modulePath.set(file(\"/path/to/module\")) }" + } + + val outDir = outputDir.get().asFile + if (outDir.isDirectory) { + outDir.listFiles()?.forEach { it.deleteRecursively() } + } + outDir.mkdirs() + + val cliPath = if (cli.isPresent) { + cli.get().asFile.absolutePath + } else { + "spacetimedb-cli" + } + + try { + execOps.exec { spec -> + spec.commandLine( + cliPath, "generate", + "--lang", "kotlin", + "--out-dir", outDir.absolutePath, + "--module-path", modulePath.get().asFile.absolutePath, + ) + } + } catch (e: Exception) { + if (!cli.isPresent) { + logger.warn( + "spacetimedb-cli not found — Kotlin bindings will not be auto-generated. 
" + + "Install from https://spacetimedb.com or set: spacetimedb { cli.set(file(\"/path/to/spacetimedb-cli\")) }" + ) + return + } + throw e + } + } +} diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/GenerateConfigTask.kt b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/GenerateConfigTask.kt new file mode 100644 index 00000000000..7bebaedbc1e --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/GenerateConfigTask.kt @@ -0,0 +1,82 @@ +package com.clockworklabs.spacetimedb + +import org.gradle.api.DefaultTask +import org.gradle.api.file.DirectoryProperty +import org.gradle.api.file.RegularFileProperty +import org.gradle.api.tasks.InputFile +import org.gradle.api.tasks.Optional +import org.gradle.api.tasks.OutputDirectory +import org.gradle.api.tasks.PathSensitive +import org.gradle.api.tasks.PathSensitivity +import org.gradle.api.tasks.TaskAction + +/** + * Reads configuration from spacetime.local.json (or spacetime.json) + * and generates a SpacetimeConfig.kt with build-time constants. + */ +abstract class GenerateConfigTask : DefaultTask() { + + @get:InputFile + @get:Optional + @get:PathSensitive(PathSensitivity.RELATIVE) + abstract val localConfig: RegularFileProperty + + @get:InputFile + @get:Optional + @get:PathSensitive(PathSensitivity.RELATIVE) + abstract val mainConfig: RegularFileProperty + + @get:OutputDirectory + abstract val outputDir: DirectoryProperty + + init { + group = "spacetimedb" + description = "Generate SpacetimeConfig.kt from SpacetimeDB project config" + } + + @TaskAction + fun generate() { + val localJson = readJson(localConfig) + val mainJson = readJson(mainConfig) + + fun field(key: String): String? 
= localJson?.get(key) ?: mainJson?.get(key) + + val dbName = field("database") + val modulePath = field("module-path") + + if (dbName == null && modulePath == null) { + logger.warn("No config found in spacetime.local.json or spacetime.json — skipping SpacetimeConfig generation") + return + } + + val outDir = outputDir.get().asFile + outDir.mkdirs() + + val code = buildString { + appendLine("// THIS FILE IS AUTOMATICALLY GENERATED BY THE SPACETIMEDB GRADLE PLUGIN.") + appendLine("// DO NOT EDIT — changes will be overwritten on next build.") + appendLine() + appendLine("package module_bindings") + appendLine() + appendLine("object SpacetimeConfig {") + if (dbName != null) { + appendLine(" const val DATABASE_NAME: String = \"$dbName\"") + } + if (modulePath != null) { + appendLine(" const val MODULE_PATH: String = \"$modulePath\"") + } + appendLine("}") + appendLine() + } + + outDir.resolve("SpacetimeConfig.kt").writeText(code) + } + + private fun readJson(file: RegularFileProperty): Map? { + if (!file.isPresent) return null + val f = file.get().asFile + if (!f.isFile) return null + @Suppress("UNCHECKED_CAST") + return groovy.json.JsonSlurper().parseText(f.readText()) as? Map + } +} diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/SpacetimeDbExtension.kt b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/SpacetimeDbExtension.kt new file mode 100644 index 00000000000..4951c00741b --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/SpacetimeDbExtension.kt @@ -0,0 +1,18 @@ +package com.clockworklabs.spacetimedb + +import org.gradle.api.file.DirectoryProperty +import org.gradle.api.file.RegularFileProperty + +abstract class SpacetimeDbExtension { + /** Path to the spacetimedb-cli binary. Defaults to "spacetimedb-cli" on the PATH. */ + abstract val cli: RegularFileProperty + + /** Path to the SpacetimeDB module directory. 
Defaults to "spacetimedb/" in the root project. */ + abstract val modulePath: DirectoryProperty + + /** Path to spacetime.local.json. Defaults to "spacetime.local.json" in the root project. */ + abstract val localConfig: RegularFileProperty + + /** Path to spacetime.json. Defaults to "spacetime.json" in the root project. */ + abstract val mainConfig: RegularFileProperty +} diff --git a/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/SpacetimeDbPlugin.kt b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/SpacetimeDbPlugin.kt new file mode 100644 index 00000000000..aa5b9fd4342 --- /dev/null +++ b/sdks/kotlin/spacetimedb-gradle-plugin/src/main/kotlin/com/clockworklabs/spacetimedb/SpacetimeDbPlugin.kt @@ -0,0 +1,88 @@ +package com.clockworklabs.spacetimedb + +import org.gradle.api.Plugin +import org.gradle.api.Project +import org.gradle.api.tasks.Delete +import org.gradle.api.tasks.SourceSetContainer +import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension + +class SpacetimeDbPlugin : Plugin { + + override fun apply(project: Project) { + val ext = project.extensions.create("spacetimedb", SpacetimeDbExtension::class.java) + + val rootDir = project.rootProject.layout.projectDirectory + ext.localConfig.convention(rootDir.file("spacetime.local.json")) + ext.mainConfig.convention(rootDir.file("spacetime.json")) + + // Derive modulePath default from spacetime.json's "module-path", fall back to "spacetimedb" + val configModulePath = readConfigField(rootDir.asFile, "module-path") + ext.modulePath.convention(rootDir.dir(configModulePath ?: "spacetimedb")) + + val bindingsDir = project.layout.buildDirectory.dir("generated/spacetimedb/bindings") + val configDir = project.layout.buildDirectory.dir("generated/spacetimedb/config") + + // Clean the Rust target directory when running `gradle clean` + project.tasks.register("cleanSpacetimeModule", Delete::class.java) { + it.group = "spacetimedb" + 
it.description = "Clean SpacetimeDB module build artifacts" + it.delete(ext.modulePath.map { dir -> dir.dir("target") }) + } + project.plugins.withType(org.gradle.api.plugins.BasePlugin::class.java) { + project.tasks.named("clean") { it.dependsOn("cleanSpacetimeModule") } + } + + val generateTask = project.tasks.register("generateSpacetimeBindings", GenerateBindingsTask::class.java) { + it.cli.set(ext.cli) + it.modulePath.set(ext.modulePath) + it.moduleSourceFiles.from(ext.modulePath.map { dir -> + project.fileTree(dir) { tree -> tree.exclude("target") } + }) + it.outputDir.set(bindingsDir) + } + + val configTask = project.tasks.register("generateSpacetimeConfig", GenerateConfigTask::class.java) { + val localFile = ext.localConfig + val mainFile = ext.mainConfig + if (localFile.isPresent && localFile.get().asFile.exists()) it.localConfig.set(localFile) + if (mainFile.isPresent && mainFile.get().asFile.exists()) it.mainConfig.set(mainFile) + it.outputDir.set(configDir) + } + + // Wire generated sources into Kotlin compilation + project.pluginManager.withPlugin("org.jetbrains.kotlin.jvm") { + val sourceSets = project.extensions.getByType(SourceSetContainer::class.java) + sourceSets.getByName("main").java.srcDir(bindingsDir) + sourceSets.getByName("main").java.srcDir(configDir) + + project.tasks.named("compileKotlin") { + it.dependsOn(generateTask) + it.dependsOn(configTask) + } + } + + project.pluginManager.withPlugin("org.jetbrains.kotlin.multiplatform") { + val kmpSourceSets = project.extensions.getByType(KotlinMultiplatformExtension::class.java).sourceSets + kmpSourceSets.getByName("commonMain").kotlin.srcDir(bindingsDir) + kmpSourceSets.getByName("commonMain").kotlin.srcDir(configDir) + + project.tasks.withType(org.jetbrains.kotlin.gradle.tasks.AbstractKotlinCompileTool::class.java).configureEach { + it.dependsOn(generateTask) + it.dependsOn(configTask) + } + } + } + + /** Read a field from spacetime.local.json or spacetime.json in the given directory. 
*/ + private fun readConfigField(dir: java.io.File, field: String): String? { + for (name in listOf("spacetime.local.json", "spacetime.json")) { + val file = dir.resolve(name) + if (file.isFile) { + val parsed = groovy.json.JsonSlurper().parseText(file.readText()) + val value = (parsed as? Map<*, *>)?.get(field) as? String + if (value != null) return value + } + } + return null + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/build.gradle.kts b/sdks/kotlin/spacetimedb-sdk/build.gradle.kts new file mode 100644 index 00000000000..7d26ab4636d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/build.gradle.kts @@ -0,0 +1,66 @@ +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.androidKotlinMultiplatformLibrary) +} + +group = "com.clockworklabs" +version = "0.1.0" + +kotlin { + explicitApi() + + android { + compileSdk = libs.versions.android.compileSdk.get().toInt() + minSdk = libs.versions.android.minSdk.get().toInt() + namespace = "com.clockworklabs.spacetimedb_kotlin_sdk.shared_client" + } + + listOf( + iosX64(), + iosArm64(), + iosSimulatorArm64() + ).forEach { iosTarget -> + iosTarget.binaries.framework { + baseName = "SpacetimeDBSdk" + isStatic = true + } + } + + jvm() + + sourceSets { + commonMain.dependencies { + implementation(libs.kotlinx.collections.immutable) + implementation(libs.kotlinx.atomicfu) + + implementation(libs.ktor.client.core) + implementation(libs.ktor.client.websockets) + } + + jvmMain.dependencies { + implementation(libs.brotli.dec) + } + + androidMain.dependencies { + implementation(libs.brotli.dec) + } + + commonTest.dependencies { + implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) + } + + jvmTest.dependencies { + implementation(libs.ktor.client.okhttp) + } + + all { + languageSettings { + optIn("kotlin.uuid.ExperimentalUuidApi") + optIn("com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi") + } + } + + compilerOptions.freeCompilerArgs.add("-Xexpect-actual-classes") + 
} +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/androidMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.android.kt b/sdks/kotlin/spacetimedb-sdk/src/androidMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.android.kt new file mode 100644 index 00000000000..dcccafaebb3 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/androidMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.android.kt @@ -0,0 +1,33 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPInputStream +import org.brotli.dec.BrotliInputStream + +internal actual fun decompressMessage(data: ByteArray): DecompressedPayload { + require(data.isNotEmpty()) { "Empty message" } + + return when (val tag = data[0]) { + Compression.NONE -> DecompressedPayload(data, offset = 1) + Compression.BROTLI -> { + val input = BrotliInputStream(ByteArrayInputStream(data, 1, data.size - 1)) + val output = ByteArrayOutputStream() + input.use { it.copyTo(output) } + DecompressedPayload(output.toByteArray()) + } + Compression.GZIP -> { + val input = GZIPInputStream(ByteArrayInputStream(data, 1, data.size - 1)) + val output = ByteArrayOutputStream() + input.use { it.copyTo(output) } + DecompressedPayload(output.toByteArray()) + } + else -> error("Unknown compression tag: $tag") + } +} + +internal actual val defaultCompressionMode: CompressionMode = CompressionMode.GZIP + +internal actual val availableCompressionModes: Set = + setOf(CompressionMode.NONE, CompressionMode.BROTLI, CompressionMode.GZIP) diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BigInteger.kt 
b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BigInteger.kt new file mode 100644 index 00000000000..ece41321192 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BigInteger.kt @@ -0,0 +1,485 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * Sign of a BigInteger magnitude used with [BigInteger.fromByteArray]. + */ +public enum class Sign { POSITIVE, NEGATIVE, ZERO } + +/** + * A fixed-width big integer backed by a canonical little-endian two's complement [ByteArray]. + * + * Designed for fast construction from BSATN wire bytes (which are already LE), + * avoiding the allocation overhead of arbitrary-precision libraries. + */ +public class BigInteger private constructor( + // Canonical LE two's complement bytes. Always at least 1 byte. + // Canonical means no redundant sign-extension bytes at the high end. + internal val leBytes: ByteArray +) : Comparable { + + /** Constructs a BigInteger from a [Long] value. */ + public constructor(value: Long) : this(longToLeBytes(value)) + + /** Constructs a BigInteger from an [Int] value. */ + public constructor(value: Int) : this(value.toLong()) + + // ---- Companion: constants and factories ---- + + public companion object { + private val HEX_CHARS = "0123456789abcdef".toCharArray() + + /** The BigInteger constant zero. */ + public val ZERO: BigInteger = BigInteger(byteArrayOf(0)) + + /** The BigInteger constant one. */ + public val ONE: BigInteger = BigInteger(1L) + + /** The BigInteger constant two. */ + public val TWO: BigInteger = BigInteger(2L) + + /** The BigInteger constant ten. */ + public val TEN: BigInteger = BigInteger(10L) + + /** + * Parses a string representation of a BigInteger in the given [radix]. + * Supports radix 10 (decimal) and 16 (hexadecimal). Negative values use a leading '-'. 
+ */ + public fun parseString(value: String, radix: Int = 10): BigInteger { + require(value.isNotEmpty()) { "Empty string" } + return when (radix) { + 10 -> parseDecimal(value) + 16 -> parseHex(value) + else -> throw IllegalArgumentException("Unsupported radix: $radix") + } + } + + /** Creates a BigInteger from an unsigned [ULong] value. */ + public fun fromULong(value: ULong): BigInteger { + if (value == 0UL) return ZERO + val bytes = ByteArray(8) + var v = value + for (i in 0 until 8) { + bytes[i] = (v and 0xFFu).toByte() + v = v shr 8 + } + // If bit 63 is set, the byte would look negative in two's complement; add sign byte + val leBytes = if (bytes[7].toInt() and 0x80 != 0) { + bytes.copyOf(9) // extra byte is 0x00 + } else { + bytes + } + return BigInteger(canonicalize(leBytes)) + } + + /** + * Creates a BigInteger from a big-endian unsigned magnitude byte array and a [sign]. + * This matches the ionspin `BigInteger.fromByteArray(bytes, sign)` contract. + */ + public fun fromByteArray(bytes: ByteArray, sign: Sign): BigInteger { + if (sign == Sign.ZERO || bytes.all { it == 0.toByte() }) return ZERO + + // Reverse BE magnitude to LE + val le = bytes.reversedArray() + + // Ensure non-negative two's complement (add 0x00 sign byte if high bit set) + val positive = if (le.last().toInt() and 0x80 != 0) { + le.copyOf(le.size + 1) + } else { + le + } + + val canonical = canonicalize(positive) + return if (sign == Sign.NEGATIVE) { + BigInteger(canonical).unaryMinus() + } else { + BigInteger(canonical) + } + } + + /** + * Constructs a BigInteger from LE two's complement bytes (signed interpretation). + * Used by BsatnReader for signed integer types (I128, I256). + */ + internal fun fromLeBytes(source: ByteArray, offset: Int, length: Int): BigInteger { + val bytes = source.copyOfRange(offset, offset + length) + return BigInteger(canonicalize(bytes)) + } + + /** + * Constructs a non-negative BigInteger from LE bytes (unsigned interpretation). 
+ * If the high bit is set, a zero byte is appended to force a positive two's complement value. + * Used by BsatnReader for unsigned integer types (U128, U256). + */ + internal fun fromLeBytesUnsigned(source: ByteArray, offset: Int, length: Int): BigInteger { + val bytes = source.copyOfRange(offset, offset + length) + val unsigned = if (bytes[length - 1].toInt() and 0x80 != 0) { + bytes.copyOf(length + 1) // extra 0x00 forces non-negative + } else { + bytes + } + return BigInteger(canonicalize(unsigned)) + } + + // ---- Internal helpers ---- + + private fun longToLeBytes(value: Long): ByteArray { + val bytes = ByteArray(8) + var v = value + for (i in 0 until 8) { + bytes[i] = (v and 0xFF).toByte() + v = v shr 8 + } + return canonicalize(bytes) + } + + /** + * Strips redundant sign-extension bytes from the high end of LE two's complement bytes. + * Returns a minimal representation (at least 1 byte). + */ + internal fun canonicalize(bytes: ByteArray): ByteArray { + if (bytes.isEmpty()) return byteArrayOf(0) + var len = bytes.size + val isNegative = bytes[len - 1].toInt() and 0x80 != 0 + val signExt = if (isNegative) 0xFF.toByte() else 0x00.toByte() + + while (len > 1) { + if (bytes[len - 1] != signExt) break + // Can only strip if the next byte preserves the sign + if ((bytes[len - 2].toInt() and 0x80 != 0) != isNegative) break + len-- + } + return if (len == bytes.size) bytes else bytes.copyOfRange(0, len) + } + + /** Sign-extends LE bytes to the given [size]. 
*/ + private fun signExtend(bytes: ByteArray, size: Int): ByteArray { + if (size <= bytes.size) return bytes + val result = bytes.copyOf(size) + if (bytes.last().toInt() and 0x80 != 0) { + for (i in bytes.size until size) result[i] = 0xFF.toByte() + } + return result + } + + private fun parseDecimal(str: String): BigInteger { + val isNeg = str.startsWith('-') + val digits = if (isNeg) str.substring(1) else str + require(digits.isNotEmpty() && digits.all { it in '0'..'9' }) { + "Invalid decimal string: $str" + } + + var magnitude = byteArrayOf(0) // LE unsigned magnitude + for (ch in digits) { + magnitude = multiplyByAndAdd(magnitude, 10, ch - '0') + } + + // Ensure the magnitude is positive in two's complement + if (magnitude.last().toInt() and 0x80 != 0) { + magnitude = magnitude.copyOf(magnitude.size + 1) // add 0x00 sign byte + } + + val canonical = canonicalize(magnitude) + return if (isNeg && !(canonical.size == 1 && canonical[0] == 0.toByte())) { + BigInteger(canonical).unaryMinus() + } else { + BigInteger(canonical) + } + } + + private fun parseHex(str: String): BigInteger { + val isNeg = str.startsWith('-') + val hexStr = if (isNeg) str.substring(1) else str + require(hexStr.isNotEmpty() && hexStr.all { it in '0'..'9' || it in 'a'..'f' || it in 'A'..'F' }) { + "Invalid hex string: $str" + } + + // Pad to even length, convert to BE bytes + val padded = if (hexStr.length % 2 != 0) "0$hexStr" else hexStr + val beBytes = ByteArray(padded.length / 2) { i -> + padded.substring(i * 2, i * 2 + 2).toInt(16).toByte() + } + + // Reverse to LE + val le = beBytes.reversedArray() + + // Ensure non-negative two's complement + val positive = if (le.isNotEmpty() && le.last().toInt() and 0x80 != 0) { + le.copyOf(le.size + 1) + } else { + le + } + + val canonical = canonicalize(positive) + return if (isNeg && !(canonical.size == 1 && canonical[0] == 0.toByte())) { + BigInteger(canonical).unaryMinus() + } else { + BigInteger(canonical) + } + } + + /** + * Multiplies an 
unsigned LE magnitude by [factor] and adds [addend]. + * Returns a new array one byte larger to accommodate overflow. + */ + private fun multiplyByAndAdd(bytes: ByteArray, factor: Int, addend: Int): ByteArray { + val result = ByteArray(bytes.size + 1) + var carry = addend + for (i in bytes.indices) { + val v = (bytes[i].toInt() and 0xFF) * factor + carry + result[i] = (v and 0xFF).toByte() + carry = v shr 8 + } + result[bytes.size] = (carry and 0xFF).toByte() + return result + } + } + + // ---- Arithmetic ---- + + /** Returns the sum of this and [other]. */ + public fun add(other: BigInteger): BigInteger { + val maxLen = maxOf(leBytes.size, other.leBytes.size) + 1 + val a = signExtend(leBytes, maxLen) + val b = signExtend(other.leBytes, maxLen) + + val result = ByteArray(maxLen) + var carry = 0 + for (i in 0 until maxLen) { + val sum = (a[i].toInt() and 0xFF) + (b[i].toInt() and 0xFF) + carry + result[i] = (sum and 0xFF).toByte() + carry = sum shr 8 + } + return BigInteger(canonicalize(result)) + } + + public operator fun plus(other: BigInteger): BigInteger = add(other) + public operator fun minus(other: BigInteger): BigInteger = add(-other) + + /** Returns the two's complement negation of this value. */ + public operator fun unaryMinus(): BigInteger { + if (signum() == 0) return this + // Sign-extend by 1 byte to handle overflow (e.g., negating -128 needs 9 bits for +128) + val extended = signExtend(leBytes, leBytes.size + 1) + // Invert all bits + for (i in extended.indices) { + extended[i] = extended[i].toInt().inv().toByte() + } + // Add 1 + var carry = 1 + for (i in extended.indices) { + val sum = (extended[i].toInt() and 0xFF) + carry + extended[i] = (sum and 0xFF).toByte() + carry = sum shr 8 + if (carry == 0) break + } + return BigInteger(canonicalize(extended)) + } + + /** Left-shifts this value by [n] bits. 
*/ + public fun shl(n: Int): BigInteger { + require(n >= 0) { "Shift amount must be non-negative: $n" } + if (n == 0 || signum() == 0) return this + + val byteShift = n / 8 + val bitShift = n % 8 + + // Allocate: original size + byte shift + 1 for bit overflow + val newSize = leBytes.size + byteShift + 1 + val result = ByteArray(newSize) + + // Copy original bytes at the shifted position + leBytes.copyInto(result, byteShift) + + // Sign-extend the high bytes beyond the original data + if (signum() < 0) { + for (i in leBytes.size + byteShift until newSize) { + result[i] = 0xFF.toByte() + } + } + + // Apply bit shift + if (bitShift > 0) { + var carry = 0 + for (i in byteShift until newSize) { + val v = ((result[i].toInt() and 0xFF) shl bitShift) or carry + result[i] = (v and 0xFF).toByte() + carry = (v shr 8) and 0xFF + } + } + + return BigInteger(canonicalize(result)) + } + + // ---- Properties ---- + + /** Returns -1, 0, or 1 as this value is negative, zero, or positive. */ + public fun signum(): Int { + val isNeg = leBytes.last().toInt() and 0x80 != 0 + if (isNeg) return -1 + // Check if all bytes are zero + for (b in leBytes) { + if (b != 0.toByte()) return 1 + } + return 0 + } + + /** Returns true if this value fits in [n] bytes of signed two's complement. */ + internal fun fitsInSignedBytes(n: Int): Boolean = leBytes.size <= n + + /** Returns true if this non-negative value fits in [n] bytes of unsigned representation. */ + internal fun fitsInUnsignedBytes(n: Int): Boolean { + if (signum() < 0) return false + // Canonical positive value may have a trailing 0x00 sign byte. + // The unsigned magnitude is leBytes without that trailing sign byte. + return leBytes.size <= n || + (leBytes.size == n + 1 && leBytes[n] == 0.toByte()) + } + + // ---- Conversion ---- + + /** + * Returns the big-endian two's complement byte array representation. + * This matches the convention of `java.math.BigInteger.toByteArray()`. 
+ */ + public fun toByteArray(): ByteArray = leBytes.reversedArray() + + /** + * Returns the big-endian two's complement byte array representation. + * Alias for [toByteArray] for compatibility with ionspin's extension function. + */ + public fun toTwosComplementByteArray(): ByteArray = toByteArray() + + /** + * Returns LE bytes at exactly [size] bytes, sign-extending or truncating as needed. + * Used for efficient BSATN writing and Identity/ConnectionId.toByteArray(). + */ + internal fun toLeBytesFixedWidth(size: Int): ByteArray { + val result = ByteArray(size) + writeLeBytes(result, 0, size) + return result + } + + /** + * Writes LE bytes directly into [dest] at [destOffset], padded with sign extension to [size] bytes. + * Zero-allocation write path for BsatnWriter. + */ + internal fun writeLeBytes(dest: ByteArray, destOffset: Int, size: Int) { + val copyLen = minOf(leBytes.size, size) + leBytes.copyInto(dest, destOffset, 0, copyLen) + if (copyLen < size) { + val padByte = if (signum() < 0) 0xFF.toByte() else 0x00.toByte() + for (i in copyLen until size) { + dest[destOffset + i] = padByte + } + } + } + + /** Returns the decimal string representation. */ + override fun toString(): String = toStringRadix(10) + + /** Returns the string representation in the given [radix] (10 or 16). 
*/ + public fun toString(radix: Int): String = toStringRadix(radix) + + private fun toStringRadix(radix: Int): String = when (radix) { + 10 -> toDecimalString() + 16 -> toHexString() + else -> throw IllegalArgumentException("Unsupported radix: $radix") + } + + private fun toDecimalString(): String { + val sign = signum() + if (sign == 0) return "0" + + val isNeg = sign < 0 + // Work on a copy of the unsigned magnitude + val magnitude = if (isNeg) (-this).leBytes.copyOf() else leBytes.copyOf() + + val digits = StringBuilder() + while (!isAllZero(magnitude)) { + val remainder = divideByTenInPlace(magnitude) + digits.append(('0' + remainder)) + } + + if (isNeg) digits.append('-') + return digits.reverse().toString() + } + + private fun toHexString(): String { + val sign = signum() + if (sign == 0) return "0" + if (sign < 0) return "-" + (-this).toHexString() + + val sb = StringBuilder() + var leading = true + for (i in leBytes.size - 1 downTo 0) { + val b = leBytes[i].toInt() and 0xFF + val hi = b shr 4 + val lo = b and 0x0F + if (leading) { + if (hi != 0) { + sb.append(HEX_CHARS[hi]) + sb.append(HEX_CHARS[lo]) + leading = false + } else if (lo != 0) { + sb.append(HEX_CHARS[lo]) + leading = false + } + } else { + sb.append(HEX_CHARS[hi]) + sb.append(HEX_CHARS[lo]) + } + } + return if (sb.isEmpty()) "0" else sb.toString() + } + + // ---- Comparison and equality ---- + + override fun compareTo(other: BigInteger): Int { + val thisSign = signum() + val otherSign = other.signum() + + if (thisSign != otherSign) return thisSign.compareTo(otherSign) + if (thisSign == 0) return 0 + + // Same sign: sign-extend to equal length and compare from MSB + val maxLen = maxOf(leBytes.size, other.leBytes.size) + val a = signExtend(leBytes, maxLen) + val b = signExtend(other.leBytes, maxLen) + + for (i in maxLen - 1 downTo 0) { + val av = a[i].toInt() and 0xFF + val bv = b[i].toInt() and 0xFF + if (av != bv) return av.compareTo(bv) + } + return 0 + } + + override fun equals(other: Any?): 
Boolean { + if (this === other) return true + if (other !is BigInteger) return false + return leBytes.contentEquals(other.leBytes) + } + + override fun hashCode(): Int = leBytes.contentHashCode() + + // ---- Private helpers ---- + + private fun isAllZero(bytes: ByteArray): Boolean { + for (b in bytes) if (b != 0.toByte()) return false + return true + } + + /** + * Divides the unsigned LE magnitude in-place by 10 and returns the remainder (0-9). + * Processes from MSB (highest index) to LSB for schoolbook division. + */ + private fun divideByTenInPlace(bytes: ByteArray): Int { + var carry = 0 + for (i in bytes.size - 1 downTo 0) { + val cur = carry * 256 + (bytes[i].toInt() and 0xFF) + bytes[i] = (cur / 10).toByte() + carry = cur % 10 + } + return carry + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BoolExpr.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BoolExpr.kt new file mode 100644 index 00000000000..dce318ad04b --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BoolExpr.kt @@ -0,0 +1,20 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.jvm.JvmInline + +/** + * A type-safe boolean SQL expression. + * The type parameter [TRow] tracks which table row type this expression applies to. + * Constructed via column comparison methods on [Col] and [IxCol]. + */ +@JvmInline +public value class BoolExpr<@Suppress("unused") TRow>(public val sql: String) { + /** Returns a new expression that is the logical AND of this and [other]. */ + public fun and(other: BoolExpr): BoolExpr = BoolExpr("($sql AND ${other.sql})") + + /** Returns a new expression that is the logical OR of this and [other]. */ + public fun or(other: BoolExpr): BoolExpr = BoolExpr("($sql OR ${other.sql})") + + /** Returns the logical negation of this expression. 
*/ + public fun not(): BoolExpr = BoolExpr("(NOT $sql)") +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackList.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackList.kt new file mode 100644 index 00000000000..e51c226f07f --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackList.kt @@ -0,0 +1,28 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.persistentListOf + +/** + * Thread-safe callback list backed by an atomic persistent list. + * Reads are zero-copy snapshots; writes use atomic CAS. + */ +@InternalSpacetimeApi +public class CallbackList { + private val list = atomic(persistentListOf()) + + /** Registers a callback. */ + public fun add(cb: T) { list.update { it.add(cb) } } + /** Removes a previously registered callback. */ + public fun remove(cb: T) { list.update { it.remove(cb) } } + /** Whether this list contains no callbacks. */ + public fun isEmpty(): Boolean = list.value.isEmpty() + /** Whether this list contains at least one callback. */ + public fun isNotEmpty(): Boolean = list.value.isNotEmpty() + + /** Invokes [action] on a snapshot of currently registered callbacks. 
*/ + public fun forEach(action: (T) -> Unit) { + for (item in list.value) action(item) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ClientCache.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ClientCache.kt new file mode 100644 index 00000000000..4d8d5ac476d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ClientCache.kt @@ -0,0 +1,506 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.BsatnRowList +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.RowSizeHint +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TableUpdateRows +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.persistentHashMapOf +import kotlinx.collections.immutable.persistentListOf + +/** + * Wrapper for ByteArray that provides structural equality/hashCode. + * Used as a map key for rows without a primary key (content-based keying via BSATN bytes). + */ +internal class BsatnRowKey(val bytes: ByteArray) { + override fun equals(other: Any?): Boolean = + other is BsatnRowKey && bytes.contentEquals(other.bytes) + + override fun hashCode(): Int = bytes.contentHashCode() +} + +/** + * Callback that fires after table operations are applied. + */ +internal fun interface PendingCallback { + /** Executes this deferred callback. */ + fun invoke() +} + +/** + * A decoded row paired with its raw BSATN bytes (used for content-based keying). 
+ */ +internal data class DecodedRow(val row: Row, val rawBytes: ByteArray) { + override fun equals(other: Any?): Boolean = + other is DecodedRow<*> && row == other.row && rawBytes.contentEquals(other.rawBytes) + + override fun hashCode(): Int = 31 * row.hashCode() + rawBytes.contentHashCode() +} + +/** + * Type-erased marker for pre-decoded row data. + * Produced by [TableCache.parseUpdate] / [TableCache.parseDeletes], + * consumed by preApply/apply methods. + * rows are decoded once and the parsed result is passed to all phases. + */ +internal interface ParsedTableData + +internal class ParsedPersistentUpdate( + val deletes: List>, + val inserts: List>, +) : ParsedTableData + +internal class ParsedEventUpdate( + val events: List, +) : ParsedTableData + +internal class ParsedDeletesOnly( + val rows: List>, +) : ParsedTableData + +/** + * Per-table cache entry. Stores rows with reference counting + * to handle overlapping subscriptions. + * + * Rows are keyed by their primary key (or full encoded bytes if no PK). + * + * @param Row the row type stored in this cache + * @param Key the key type used to identify rows (typed PK or BsatnRowKey) + */ +@InternalSpacetimeApi +public class TableCache private constructor( + private val decode: (BsatnReader) -> Row, + private val keyExtractor: (Row, ByteArray) -> Key, +) { + public companion object { + /** Creates a table cache that keys rows by an extracted primary key. */ + @InternalSpacetimeApi + public fun withPrimaryKey( + decode: (BsatnReader) -> Row, + primaryKey: (Row) -> Key, + ): TableCache = TableCache(decode) { row, _ -> primaryKey(row) } + + /** Creates a table cache that keys rows by their full BSATN-encoded bytes. 
*/ + @InternalSpacetimeApi + @Suppress("UNCHECKED_CAST") + public fun withContentKey( + decode: (BsatnReader) -> Row, + ): TableCache = TableCache(decode) { _, bytes -> BsatnRowKey(bytes) } + } + + // Map> — atomic persistent map for thread-safe reads + private val _rows = atomic(persistentHashMapOf>()) + + private val _onInsertCallbacks = atomic(persistentListOf<(EventContext, Row) -> Unit>()) + private val _onDeleteCallbacks = atomic(persistentListOf<(EventContext, Row) -> Unit>()) + private val _onUpdateCallbacks = atomic(persistentListOf<(EventContext, Row, Row) -> Unit>()) + private val _onBeforeDeleteCallbacks = atomic(persistentListOf<(EventContext, Row) -> Unit>()) + + private val _internalInsertListeners = atomic(persistentListOf<(Row) -> Unit>()) + private val _internalDeleteListeners = atomic(persistentListOf<(Row) -> Unit>()) + + internal fun addInternalInsertListener(cb: (Row) -> Unit) { _internalInsertListeners.update { it.add(cb) } } + internal fun addInternalDeleteListener(cb: (Row) -> Unit) { _internalDeleteListeners.update { it.add(cb) } } + + /** Registers a callback that fires after a row is inserted. */ + public fun onInsert(cb: (EventContext, Row) -> Unit) { _onInsertCallbacks.update { it.add(cb) } } + + /** Registers a callback that fires after a row is deleted. */ + public fun onDelete(cb: (EventContext, Row) -> Unit) { _onDeleteCallbacks.update { it.add(cb) } } + + /** Registers a callback that fires after a row is updated (old row, new row). */ + public fun onUpdate(cb: (EventContext, Row, Row) -> Unit) { _onUpdateCallbacks.update { it.add(cb) } } + + /** Registers a callback that fires before a row is deleted. */ + public fun onBeforeDelete(cb: (EventContext, Row) -> Unit) { _onBeforeDeleteCallbacks.update { it.add(cb) } } + + /** Removes a previously registered insert callback. 
*/ + public fun removeOnInsert(cb: (EventContext, Row) -> Unit) { _onInsertCallbacks.update { it.remove(cb) } } + + /** Removes a previously registered delete callback. */ + public fun removeOnDelete(cb: (EventContext, Row) -> Unit) { _onDeleteCallbacks.update { it.remove(cb) } } + + /** Removes a previously registered update callback. */ + public fun removeOnUpdate(cb: (EventContext, Row, Row) -> Unit) { _onUpdateCallbacks.update { it.remove(cb) } } + + /** Removes a previously registered before-delete callback. */ + public fun removeOnBeforeDelete(cb: (EventContext, Row) -> Unit) { _onBeforeDeleteCallbacks.update { it.remove(cb) } } + + /** Returns the number of rows currently stored in this table. */ + public fun count(): Int = _rows.value.size + + /** Returns a lazy sequence over all rows in this table. */ + public fun iter(): Sequence = _rows.value.values.asSequence().map { it.first } + + /** Returns a snapshot list of all rows in this table. */ + public fun all(): List = _rows.value.values.map { it.first } + + /** + * Decode rows from a BsatnRowList, capturing raw BSATN bytes per row. 
+ */ + private fun decodeRowListWithBytes(rowList: BsatnRowList): List> { + if (rowList.rowsSize == 0) return emptyList() + val reader = rowList.rowsReader + val result = mutableListOf>() + val rowCount = when (val hint = rowList.sizeHint) { + is RowSizeHint.FixedSize -> { + val rowSize = hint.size.toInt() + require(rowSize > 0) { "Server sent FixedSize(0), which violates the protocol invariant" } + require(rowList.rowsSize % rowSize == 0) { + "FixedSize row data not evenly divisible: ${rowList.rowsSize} bytes / $rowSize row size" + } + rowList.rowsSize / rowSize + } + is RowSizeHint.RowOffsets -> hint.offsets.size + } + repeat(rowCount) { + val startOffset = reader.offset + val row = decode(reader) + val rawBytes = reader.sliceArray(startOffset, reader.offset) + result.add(DecodedRow(row, rawBytes)) + } + return result + } + + /** Decodes all rows from a [BsatnRowList], discarding raw bytes. */ + internal fun decodeRowList(rowList: BsatnRowList): List = + decodeRowListWithBytes(rowList).map { it.row } + + // --- Parse phase: decode once, reuse across preApply/apply --- + + /** + * Decode a [TableUpdateRows] into a [ParsedTableData] that can be passed + * to [preApplyUpdate] and [applyUpdate]. Rows are decoded exactly once. + */ + internal fun parseUpdate(update: TableUpdateRows): ParsedTableData = when (update) { + is TableUpdateRows.PersistentTable -> ParsedPersistentUpdate( + deletes = decodeRowListWithBytes(update.deletes), + inserts = decodeRowListWithBytes(update.inserts), + ) + is TableUpdateRows.EventTable -> ParsedEventUpdate( + events = decodeRowListWithBytes(update.events).map { it.row }, + ) + } + + /** + * Decode a [BsatnRowList] of deletes into a [ParsedTableData] that can be + * passed to [preApplyDeletes] and [applyDeletes]. Rows are decoded exactly once. 
+ */ + internal fun parseDeletes(rowList: BsatnRowList): ParsedTableData = + ParsedDeletesOnly(rows = decodeRowListWithBytes(rowList)) + + // --- Insert (single-phase, no pre-apply needed) --- + + /** + * Apply insert operations from a BsatnRowList. + * Returns pending callbacks to execute after all tables are updated. + */ + internal fun applyInserts(ctx: EventContext, rowList: BsatnRowList): List { + val decoded = decodeRowListWithBytes(rowList) + val callbacks = mutableListOf() + val newInserts = mutableListOf() + _rows.update { current -> + callbacks.clear() + newInserts.clear() + val insertCbs = _onInsertCallbacks.value + var snapshot = current + for ((row, rawBytes) in decoded) { + val id = keyExtractor(row, rawBytes) + val existing = snapshot[id] + if (existing != null) { + snapshot = snapshot.put(id, Pair(existing.first, existing.second + 1)) + } else { + snapshot = snapshot.put(id, Pair(row, 1)) + newInserts.add(row) + if (insertCbs.isNotEmpty()) { + callbacks.add(PendingCallback { + for (cb in insertCbs) cb(ctx, row) + }) + } + } + } + snapshot + } + for (row in newInserts) { + for (listener in _internalInsertListeners.value) listener(row) + } + return callbacks + } + + // --- Unsubscribe deletes (two-phase) --- + + /** + * Phase 1 for unsubscribe deletes: fires onBeforeDelete callbacks + * BEFORE any mutations happen, enabling cross-table consistency. + * Accepts pre-decoded data from [parseDeletes]. + */ + @Suppress("UNCHECKED_CAST") + internal fun preApplyDeletes(ctx: EventContext, parsed: ParsedTableData) { + if (_onBeforeDeleteCallbacks.value.isEmpty()) return + val data = parsed as ParsedDeletesOnly + val snapshot = _rows.value + for ((row, rawBytes) in data.rows) { + val id = keyExtractor(row, rawBytes) + val existing = snapshot[id] ?: continue + if (existing.second <= 1) { + for (cb in _onBeforeDeleteCallbacks.value) cb(ctx, existing.first) + } + } + } + + /** + * Phase 2 for unsubscribe deletes: mutates rows and returns post-mutation callbacks. 
+ * onBeforeDelete must be called via [preApplyDeletes] before this. + * Accepts pre-decoded data from [parseDeletes]. + */ + @Suppress("UNCHECKED_CAST") + internal fun applyDeletes(ctx: EventContext, parsed: ParsedTableData): List { + val data = parsed as ParsedDeletesOnly + val callbacks = mutableListOf() + val removedRows = mutableListOf() + _rows.update { current -> + callbacks.clear() + removedRows.clear() + val deleteCbs = _onDeleteCallbacks.value + var snapshot = current + for ((row, rawBytes) in data.rows) { + val id = keyExtractor(row, rawBytes) + val existing = snapshot[id] ?: continue + if (existing.second <= 1) { + val capturedRow = existing.first + snapshot = snapshot.remove(id) + removedRows.add(capturedRow) + if (deleteCbs.isNotEmpty()) { + callbacks.add(PendingCallback { + for (cb in deleteCbs) cb(ctx, capturedRow) + }) + } + } else { + snapshot = snapshot.put(id, Pair(existing.first, existing.second - 1)) + } + } + snapshot + } + for (row in removedRows) { + for (listener in _internalDeleteListeners.value) listener(row) + } + return callbacks + } + + // --- Transaction updates (two-phase) --- + + /** + * Phase 1 for transaction updates: fires onBeforeDelete callbacks + * for rows that will be deleted (not updated), BEFORE any mutations happen. + * Accepts pre-decoded data from [parseUpdate]. + */ + @Suppress("UNCHECKED_CAST") + internal fun preApplyUpdate(ctx: EventContext, parsed: ParsedTableData) { + if (_onBeforeDeleteCallbacks.value.isEmpty()) return + val update = parsed as? 
ParsedPersistentUpdate ?: return + + // Build insert key set for update detection + val insertKeys = mutableSetOf() + for ((row, rawBytes) in update.inserts) insertKeys.add(keyExtractor(row, rawBytes)) + + // Fire onBeforeDelete for pure deletes only (not updates) + val snapshot = _rows.value + for ((row, rawBytes) in update.deletes) { + val id = keyExtractor(row, rawBytes) + if (id in insertKeys) continue // This is an update, not a delete + val existing = snapshot[id] ?: continue + if (existing.second <= 1) { + for (cb in _onBeforeDeleteCallbacks.value) cb(ctx, existing.first) + } + } + } + + /** + * Phase 2 for transaction updates: mutates rows and returns post-mutation callbacks. + * onBeforeDelete must be called via [preApplyUpdate] before this. + * Accepts pre-decoded data from [parseUpdate]. + */ + @Suppress("UNCHECKED_CAST") + internal fun applyUpdate(ctx: EventContext, parsed: ParsedTableData): List { + return when (parsed) { + is ParsedPersistentUpdate<*> -> { + val update = parsed as ParsedPersistentUpdate + + // Build delete map for pairing with inserts + val deleteMap = mutableMapOf() + for ((row, rawBytes) in update.deletes) deleteMap[keyExtractor(row, rawBytes)] = row + + val callbacks = mutableListOf() + val updatedRows = mutableListOf>() + val newInserts = mutableListOf() + val removedRows = mutableListOf() + + _rows.update { current -> + callbacks.clear() + updatedRows.clear() + newInserts.clear() + removedRows.clear() + val insertCbs = _onInsertCallbacks.value + val deleteCbs = _onDeleteCallbacks.value + val updateCbs = _onUpdateCallbacks.value + val localDeleteMap = deleteMap.toMutableMap() + var snapshot = current + + // Process inserts — check for matching delete (= update) + for ((row, rawBytes) in update.inserts) { + val id = keyExtractor(row, rawBytes) + val deletedRow = localDeleteMap.remove(id) + if (deletedRow != null) { + // Update: same key in both insert and delete + val oldRow = snapshot[id]?.first ?: deletedRow + snapshot = 
snapshot.put(id, Pair(row, snapshot[id]?.second ?: 1)) + updatedRows.add(oldRow to row) + if (updateCbs.isNotEmpty()) { + callbacks.add(PendingCallback { + for (cb in updateCbs) cb(ctx, oldRow, row) + }) + } + } else { + // Pure insert + val existing = snapshot[id] + if (existing != null) { + snapshot = snapshot.put(id, Pair(existing.first, existing.second + 1)) + } else { + snapshot = snapshot.put(id, Pair(row, 1)) + newInserts.add(row) + if (insertCbs.isNotEmpty()) { + callbacks.add(PendingCallback { + for (cb in insertCbs) cb(ctx, row) + }) + } + } + } + } + + // Remaining deletes: pure deletes (onBeforeDelete already fired in preApplyUpdate) + for ((id, _) in localDeleteMap) { + val existing = snapshot[id] ?: continue + if (existing.second <= 1) { + val capturedRow = existing.first + snapshot = snapshot.remove(id) + removedRows.add(capturedRow) + if (deleteCbs.isNotEmpty()) { + callbacks.add(PendingCallback { + for (cb in deleteCbs) cb(ctx, capturedRow) + }) + } + } else { + snapshot = snapshot.put(id, Pair(existing.first, existing.second - 1)) + } + } + + snapshot + } + + // Fire internal listeners after CAS succeeds + for ((oldRow, newRow) in updatedRows) { + for (listener in _internalDeleteListeners.value) listener(oldRow) + for (listener in _internalInsertListeners.value) listener(newRow) + } + for (row in newInserts) { + for (listener in _internalInsertListeners.value) listener(row) + } + for (row in removedRows) { + for (listener in _internalDeleteListeners.value) listener(row) + } + + callbacks + } + is ParsedEventUpdate<*> -> { + // Event table: fire insert callbacks, but don't store + val events = (parsed as ParsedEventUpdate).events + val insertCbs = _onInsertCallbacks.value + val callbacks = mutableListOf() + for (row in events) { + if (insertCbs.isNotEmpty()) { + callbacks.add(PendingCallback { + for (cb in insertCbs) cb(ctx, row) + }) + } + } + callbacks + } + else -> emptyList() + } + } + + /** + * Clear all rows (used on disconnect). 
+ */ + internal fun clear() { + val oldRows = _rows.getAndSet(persistentHashMapOf()) + val listeners = _internalDeleteListeners.value + if (listeners.isNotEmpty()) { + for ((_, pair) in oldRows) { + for (listener in listeners) listener(pair.first) + } + } + } +} + +/** + * Client-side cache holding all table caches. + * Registry of [TableCache] instances keyed by table name. + */ +@InternalSpacetimeApi +public class ClientCache { + private val _tables = atomic(persistentHashMapOf>()) + + /** Registers a [TableCache] under the given table name. */ + @InternalSpacetimeApi + public fun register(tableName: String, cache: TableCache) { + _tables.update { it.put(tableName, cache) } + } + + /** Returns the table cache for [tableName], throwing if not registered. */ + @Suppress("UNCHECKED_CAST") + internal fun getTable(tableName: String): TableCache = + _tables.value[tableName] as? TableCache + ?: error("Table '$tableName' not found in client cache") + + /** Returns the table cache for [tableName], or `null` if not registered. */ + @Suppress("UNCHECKED_CAST") + internal fun getTableOrNull(tableName: String): TableCache? = + _tables.value[tableName] as? TableCache + + /** Returns the table cache for [tableName], creating it via [factory] if not yet registered. */ + @InternalSpacetimeApi + @Suppress("UNCHECKED_CAST") + public fun getOrCreateTable(tableName: String, factory: () -> TableCache): TableCache { + // Fast path: already registered + _tables.value[tableName]?.let { return it as TableCache } + + // Create once outside the CAS loop so factory() is never called on retry + val created = factory() + var result: TableCache? = null + _tables.update { map -> + val existing = map[tableName] + if (existing != null) { + result = existing as TableCache + map + } else { + result = created + map.put(tableName, created) + } + } + return result!! + } + + /** Returns the table cache for [tableName] without casting, or `null` if not registered. 
*/ + internal fun getUntypedTable(tableName: String): TableCache<*, *>? = + _tables.value[tableName] + + /** Returns the set of all registered table names. */ + internal fun tableNames(): Set = _tables.value.keys + + /** Clears all rows from every registered table cache. */ + internal fun clear() { + for ((_, table) in _tables.value) table.clear() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Col.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Col.kt new file mode 100644 index 00000000000..04ed489874d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Col.kt @@ -0,0 +1,77 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * A typed reference to a table column. + * Supports all comparison operators (eq, neq, lt, lte, gt, gte). + * + * @param TRow the row type this column belongs to + * @param TValue the Kotlin type of this column's value + */ +public class Col @InternalSpacetimeApi constructor(tableName: String, columnName: String) { + internal val refSql: String = "${SqlFormat.quoteIdent(tableName)}.${SqlFormat.quoteIdent(columnName)}" + + /** Tests equality against a literal value. */ + public fun eq(value: SqlLiteral): BoolExpr = BoolExpr("($refSql = ${value.sql})") + + /** Tests equality against another column. */ + public fun eq(other: Col): BoolExpr = BoolExpr("($refSql = ${other.refSql})") + + /** Tests inequality against a literal value. */ + public fun neq(value: SqlLiteral): BoolExpr = BoolExpr("($refSql <> ${value.sql})") + + /** Tests inequality against another column. */ + public fun neq(other: Col): BoolExpr = BoolExpr("($refSql <> ${other.refSql})") + + /** Tests whether this column is strictly less than [value]. 
*/ + public fun lt(value: SqlLiteral): BoolExpr = BoolExpr("($refSql < ${value.sql})") + + /** Tests whether this column is less than or equal to [value]. */ + public fun lte(value: SqlLiteral): BoolExpr = BoolExpr("($refSql <= ${value.sql})") + + /** Tests whether this column is strictly greater than [value]. */ + public fun gt(value: SqlLiteral): BoolExpr = BoolExpr("($refSql > ${value.sql})") + + /** Tests whether this column is greater than or equal to [value]. */ + public fun gte(value: SqlLiteral): BoolExpr = BoolExpr("($refSql >= ${value.sql})") +} + +/** + * A typed reference to an indexed column. + * Supports eq/neq comparisons and indexed join equality. + */ +public class IxCol @InternalSpacetimeApi constructor(tableName: String, columnName: String) { + internal val refSql: String = "${SqlFormat.quoteIdent(tableName)}.${SqlFormat.quoteIdent(columnName)}" + + /** Tests equality against a literal value. */ + public fun eq(value: SqlLiteral): BoolExpr = BoolExpr("($refSql = ${value.sql})") + + /** Creates an indexed join equality condition against another indexed column. */ + @OptIn(InternalSpacetimeApi::class) + public fun eq(other: IxCol): IxJoinEq = + IxJoinEq(refSql, other.refSql) + + /** Tests inequality against a literal value. */ + public fun neq(value: SqlLiteral): BoolExpr = BoolExpr("($refSql <> ${value.sql})") + + /** Tests whether this column is strictly less than [value]. */ + public fun lt(value: SqlLiteral): BoolExpr = BoolExpr("($refSql < ${value.sql})") + + /** Tests whether this column is less than or equal to [value]. */ + public fun lte(value: SqlLiteral): BoolExpr = BoolExpr("($refSql <= ${value.sql})") + + /** Tests whether this column is strictly greater than [value]. */ + public fun gt(value: SqlLiteral): BoolExpr = BoolExpr("($refSql > ${value.sql})") + + /** Tests whether this column is greater than or equal to [value]. 
*/ + public fun gte(value: SqlLiteral): BoolExpr = BoolExpr("($refSql >= ${value.sql})") +} + +/** + * Represents an indexed equality join condition between two tables. + * Created by calling [IxCol.eq] with another indexed column. + * Used as the `on` parameter for semi-join methods. + */ +public class IxJoinEq<@Suppress("unused") TLeftRow, @Suppress("unused") TRightRow> @InternalSpacetimeApi constructor( + internal val leftRefSql: String, + internal val rightRefSql: String, +) diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ColExtensions.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ColExtensions.kt new file mode 100644 index 00000000000..5a03ca61106 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ColExtensions.kt @@ -0,0 +1,205 @@ +@file:Suppress("TooManyFunctions") + +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.SpacetimeUuid + +/** + * Type-specialized comparison extensions for [Col] and [IxCol]. + * + * Each overload accepts a native Kotlin value, converts it to a [SqlLiteral] via [SqlLit], + * and delegates to the underlying column comparison method. This avoids requiring callers + * to wrap every value in [SqlLit] manually. 
+ */ + +// ---- Col ---- + +public fun Col.eq(value: String): BoolExpr = eq(SqlLit.string(value)) +public fun Col.neq(value: String): BoolExpr = neq(SqlLit.string(value)) +public fun Col.lt(value: String): BoolExpr = lt(SqlLit.string(value)) +public fun Col.lte(value: String): BoolExpr = lte(SqlLit.string(value)) +public fun Col.gt(value: String): BoolExpr = gt(SqlLit.string(value)) +public fun Col.gte(value: String): BoolExpr = gte(SqlLit.string(value)) + +public fun IxCol.eq(value: String): BoolExpr = eq(SqlLit.string(value)) +public fun IxCol.neq(value: String): BoolExpr = neq(SqlLit.string(value)) + +// ---- Col ---- + +public fun Col.eq(value: Boolean): BoolExpr = eq(SqlLit.bool(value)) +public fun Col.neq(value: Boolean): BoolExpr = neq(SqlLit.bool(value)) +public operator fun Col.not(): BoolExpr = eq(SqlLit.bool(true)).not() + +public fun IxCol.eq(value: Boolean): BoolExpr = eq(SqlLit.bool(value)) +public fun IxCol.neq(value: Boolean): BoolExpr = neq(SqlLit.bool(value)) +public operator fun IxCol.not(): BoolExpr = eq(SqlLit.bool(true)).not() + +// ---- Col ---- + +public fun Col.eq(value: Int): BoolExpr = eq(SqlLit.int(value)) +public fun Col.neq(value: Int): BoolExpr = neq(SqlLit.int(value)) +public fun Col.lt(value: Int): BoolExpr = lt(SqlLit.int(value)) +public fun Col.lte(value: Int): BoolExpr = lte(SqlLit.int(value)) +public fun Col.gt(value: Int): BoolExpr = gt(SqlLit.int(value)) +public fun Col.gte(value: Int): BoolExpr = gte(SqlLit.int(value)) + +public fun IxCol.eq(value: Int): BoolExpr = eq(SqlLit.int(value)) +public fun IxCol.neq(value: Int): BoolExpr = neq(SqlLit.int(value)) + +// ---- Col ---- + +public fun Col.eq(value: Long): BoolExpr = eq(SqlLit.long(value)) +public fun Col.neq(value: Long): BoolExpr = neq(SqlLit.long(value)) +public fun Col.lt(value: Long): BoolExpr = lt(SqlLit.long(value)) +public fun Col.lte(value: Long): BoolExpr = lte(SqlLit.long(value)) +public fun Col.gt(value: Long): BoolExpr = gt(SqlLit.long(value)) +public fun 
Col.gte(value: Long): BoolExpr = gte(SqlLit.long(value)) + +public fun IxCol.eq(value: Long): BoolExpr = eq(SqlLit.long(value)) +public fun IxCol.neq(value: Long): BoolExpr = neq(SqlLit.long(value)) + +// ---- Col ---- + +public fun Col.eq(value: Byte): BoolExpr = eq(SqlLit.byte(value)) +public fun Col.neq(value: Byte): BoolExpr = neq(SqlLit.byte(value)) +public fun Col.lt(value: Byte): BoolExpr = lt(SqlLit.byte(value)) +public fun Col.lte(value: Byte): BoolExpr = lte(SqlLit.byte(value)) +public fun Col.gt(value: Byte): BoolExpr = gt(SqlLit.byte(value)) +public fun Col.gte(value: Byte): BoolExpr = gte(SqlLit.byte(value)) + +public fun Col.eq(value: Short): BoolExpr = eq(SqlLit.short(value)) +public fun Col.neq(value: Short): BoolExpr = neq(SqlLit.short(value)) +public fun Col.lt(value: Short): BoolExpr = lt(SqlLit.short(value)) +public fun Col.lte(value: Short): BoolExpr = lte(SqlLit.short(value)) +public fun Col.gt(value: Short): BoolExpr = gt(SqlLit.short(value)) +public fun Col.gte(value: Short): BoolExpr = gte(SqlLit.short(value)) + +public fun Col.eq(value: UByte): BoolExpr = eq(SqlLit.ubyte(value)) +public fun Col.neq(value: UByte): BoolExpr = neq(SqlLit.ubyte(value)) +public fun Col.lt(value: UByte): BoolExpr = lt(SqlLit.ubyte(value)) +public fun Col.lte(value: UByte): BoolExpr = lte(SqlLit.ubyte(value)) +public fun Col.gt(value: UByte): BoolExpr = gt(SqlLit.ubyte(value)) +public fun Col.gte(value: UByte): BoolExpr = gte(SqlLit.ubyte(value)) + +public fun Col.eq(value: UShort): BoolExpr = eq(SqlLit.ushort(value)) +public fun Col.neq(value: UShort): BoolExpr = neq(SqlLit.ushort(value)) +public fun Col.lt(value: UShort): BoolExpr = lt(SqlLit.ushort(value)) +public fun Col.lte(value: UShort): BoolExpr = lte(SqlLit.ushort(value)) +public fun Col.gt(value: UShort): BoolExpr = gt(SqlLit.ushort(value)) +public fun Col.gte(value: UShort): BoolExpr = gte(SqlLit.ushort(value)) + +public fun Col.eq(value: UInt): BoolExpr = eq(SqlLit.uint(value)) +public fun 
Col.neq(value: UInt): BoolExpr = neq(SqlLit.uint(value)) +public fun Col.lt(value: UInt): BoolExpr = lt(SqlLit.uint(value)) +public fun Col.lte(value: UInt): BoolExpr = lte(SqlLit.uint(value)) +public fun Col.gt(value: UInt): BoolExpr = gt(SqlLit.uint(value)) +public fun Col.gte(value: UInt): BoolExpr = gte(SqlLit.uint(value)) + +public fun Col.eq(value: ULong): BoolExpr = eq(SqlLit.ulong(value)) +public fun Col.neq(value: ULong): BoolExpr = neq(SqlLit.ulong(value)) +public fun Col.lt(value: ULong): BoolExpr = lt(SqlLit.ulong(value)) +public fun Col.lte(value: ULong): BoolExpr = lte(SqlLit.ulong(value)) +public fun Col.gt(value: ULong): BoolExpr = gt(SqlLit.ulong(value)) +public fun Col.gte(value: ULong): BoolExpr = gte(SqlLit.ulong(value)) + +public fun Col.eq(value: Float): BoolExpr = eq(SqlLit.float(value)) +public fun Col.neq(value: Float): BoolExpr = neq(SqlLit.float(value)) +public fun Col.lt(value: Float): BoolExpr = lt(SqlLit.float(value)) +public fun Col.lte(value: Float): BoolExpr = lte(SqlLit.float(value)) +public fun Col.gt(value: Float): BoolExpr = gt(SqlLit.float(value)) +public fun Col.gte(value: Float): BoolExpr = gte(SqlLit.float(value)) + +public fun Col.eq(value: Double): BoolExpr = eq(SqlLit.double(value)) +public fun Col.neq(value: Double): BoolExpr = neq(SqlLit.double(value)) +public fun Col.lt(value: Double): BoolExpr = lt(SqlLit.double(value)) +public fun Col.lte(value: Double): BoolExpr = lte(SqlLit.double(value)) +public fun Col.gt(value: Double): BoolExpr = gt(SqlLit.double(value)) +public fun Col.gte(value: Double): BoolExpr = gte(SqlLit.double(value)) + +public fun IxCol.eq(value: Byte): BoolExpr = eq(SqlLit.byte(value)) +public fun IxCol.neq(value: Byte): BoolExpr = neq(SqlLit.byte(value)) + +public fun IxCol.eq(value: Short): BoolExpr = eq(SqlLit.short(value)) +public fun IxCol.neq(value: Short): BoolExpr = neq(SqlLit.short(value)) + +public fun IxCol.eq(value: UByte): BoolExpr = eq(SqlLit.ubyte(value)) +public fun IxCol.neq(value: 
UByte): BoolExpr = neq(SqlLit.ubyte(value)) + +public fun IxCol.eq(value: UShort): BoolExpr = eq(SqlLit.ushort(value)) +public fun IxCol.neq(value: UShort): BoolExpr = neq(SqlLit.ushort(value)) + +public fun IxCol.eq(value: UInt): BoolExpr = eq(SqlLit.uint(value)) +public fun IxCol.neq(value: UInt): BoolExpr = neq(SqlLit.uint(value)) + +public fun IxCol.eq(value: ULong): BoolExpr = eq(SqlLit.ulong(value)) +public fun IxCol.neq(value: ULong): BoolExpr = neq(SqlLit.ulong(value)) + +public fun IxCol.eq(value: Float): BoolExpr = eq(SqlLit.float(value)) +public fun IxCol.neq(value: Float): BoolExpr = neq(SqlLit.float(value)) + +public fun IxCol.eq(value: Double): BoolExpr = eq(SqlLit.double(value)) +public fun IxCol.neq(value: Double): BoolExpr = neq(SqlLit.double(value)) + +// ---- Col ---- + +public fun Col.eq(value: Int128): BoolExpr = eq(SqlLit.int128(value)) +public fun Col.neq(value: Int128): BoolExpr = neq(SqlLit.int128(value)) +public fun Col.lt(value: Int128): BoolExpr = lt(SqlLit.int128(value)) +public fun Col.lte(value: Int128): BoolExpr = lte(SqlLit.int128(value)) +public fun Col.gt(value: Int128): BoolExpr = gt(SqlLit.int128(value)) +public fun Col.gte(value: Int128): BoolExpr = gte(SqlLit.int128(value)) + +public fun Col.eq(value: UInt128): BoolExpr = eq(SqlLit.uint128(value)) +public fun Col.neq(value: UInt128): BoolExpr = neq(SqlLit.uint128(value)) +public fun Col.lt(value: UInt128): BoolExpr = lt(SqlLit.uint128(value)) +public fun Col.lte(value: UInt128): BoolExpr = lte(SqlLit.uint128(value)) +public fun Col.gt(value: UInt128): BoolExpr = gt(SqlLit.uint128(value)) +public fun Col.gte(value: UInt128): BoolExpr = gte(SqlLit.uint128(value)) + +public fun Col.eq(value: Int256): BoolExpr = eq(SqlLit.int256(value)) +public fun Col.neq(value: Int256): BoolExpr = neq(SqlLit.int256(value)) +public fun Col.lt(value: Int256): BoolExpr = lt(SqlLit.int256(value)) +public fun Col.lte(value: Int256): BoolExpr = lte(SqlLit.int256(value)) +public fun Col.gt(value: 
Int256): BoolExpr = gt(SqlLit.int256(value)) +public fun Col.gte(value: Int256): BoolExpr = gte(SqlLit.int256(value)) + +public fun Col.eq(value: UInt256): BoolExpr = eq(SqlLit.uint256(value)) +public fun Col.neq(value: UInt256): BoolExpr = neq(SqlLit.uint256(value)) +public fun Col.lt(value: UInt256): BoolExpr = lt(SqlLit.uint256(value)) +public fun Col.lte(value: UInt256): BoolExpr = lte(SqlLit.uint256(value)) +public fun Col.gt(value: UInt256): BoolExpr = gt(SqlLit.uint256(value)) +public fun Col.gte(value: UInt256): BoolExpr = gte(SqlLit.uint256(value)) + +public fun IxCol.eq(value: Int128): BoolExpr = eq(SqlLit.int128(value)) +public fun IxCol.neq(value: Int128): BoolExpr = neq(SqlLit.int128(value)) + +public fun IxCol.eq(value: UInt128): BoolExpr = eq(SqlLit.uint128(value)) +public fun IxCol.neq(value: UInt128): BoolExpr = neq(SqlLit.uint128(value)) + +public fun IxCol.eq(value: Int256): BoolExpr = eq(SqlLit.int256(value)) +public fun IxCol.neq(value: Int256): BoolExpr = neq(SqlLit.int256(value)) + +public fun IxCol.eq(value: UInt256): BoolExpr = eq(SqlLit.uint256(value)) +public fun IxCol.neq(value: UInt256): BoolExpr = neq(SqlLit.uint256(value)) + +// ---- Col ---- + +public fun Col.eq(value: Identity): BoolExpr = eq(SqlLit.identity(value)) +public fun Col.neq(value: Identity): BoolExpr = neq(SqlLit.identity(value)) + +public fun IxCol.eq(value: Identity): BoolExpr = eq(SqlLit.identity(value)) +public fun IxCol.neq(value: Identity): BoolExpr = neq(SqlLit.identity(value)) + +public fun Col.eq(value: ConnectionId): BoolExpr = eq(SqlLit.connectionId(value)) +public fun Col.neq(value: ConnectionId): BoolExpr = neq(SqlLit.connectionId(value)) + +public fun IxCol.eq(value: ConnectionId): BoolExpr = eq(SqlLit.connectionId(value)) +public fun IxCol.neq(value: ConnectionId): BoolExpr = neq(SqlLit.connectionId(value)) + +public fun Col.eq(value: SpacetimeUuid): BoolExpr = eq(SqlLit.uuid(value)) +public fun Col.neq(value: SpacetimeUuid): BoolExpr = 
neq(SqlLit.uuid(value)) + +public fun IxCol.eq(value: SpacetimeUuid): BoolExpr = eq(SqlLit.uuid(value)) +public fun IxCol.neq(value: SpacetimeUuid): BoolExpr = neq(SqlLit.uuid(value)) diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/DbConnection.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/DbConnection.kt new file mode 100644 index 00000000000..1168473334e --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/DbConnection.kt @@ -0,0 +1,1059 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ClientMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ProcedureStatus +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QuerySetId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ReducerOutcome +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TransactionUpdate +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.UnsubscribeFlags +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.availableCompressionModes +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.defaultCompressionMode +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.SpacetimeTransport +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.Transport +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import io.ktor.client.HttpClient +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.getAndUpdate +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.persistentHashMapOf +import kotlinx.collections.immutable.persistentListOf +import kotlinx.collections.immutable.toPersistentList +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.suspendCancellableCoroutine +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlin.coroutines.resume +import kotlin.time.Duration + +/** + * Tracks reducer call info so we can populate the Event.Reducer + * with the correct name/args when the result comes back. + */ +private class ReducerCallInfo( + val name: String, + val typedArgs: Any, +) + +/** + * Decodes a BSATN-encoded reducer error into a human-readable string. + * Reducer errors are BSATN strings (u32 length + UTF-8 bytes). + * Falls back to hex dump if decoding fails. + */ +private fun decodeReducerError(bytes: ByteArray): String { + return try { + val reader = BsatnReader(bytes) + reader.readString() + } catch (_: Exception) { + "Reducer returned undecodable BSATN error bytes (len=${bytes.size})" + } +} + + +/** + * Compression mode for the WebSocket connection. + */ +public enum class CompressionMode(internal val wireValue: String) { + /** Brotli compression (JVM/Android only). 
*/ + BROTLI("Brotli"), + /** Gzip compression. */ + GZIP("Gzip"), + /** No compression. */ + NONE("None"), +} + +/** + * Connection lifecycle state machine. + * + * Each variant owns the resources created in that phase. + * [Connected] carries the coroutine jobs and exposes [Connected.shutdown] + * to cancel/join them before the cache is cleared — preventing the + * index-vs-_rows inconsistency that occurs when a CAS loop is still + * in flight. + * + * ``` + * Disconnected ──▶ Connecting ──▶ Connected ──▶ Closed + * │ ▲ + * └──────────────────────────┘ + * ``` + */ +public sealed interface ConnectionState { + /** No connection has been established yet. */ + public data object Disconnected : ConnectionState + /** A connection attempt is in progress. */ + public data object Connecting : ConnectionState + + /** The WebSocket connection is active and processing messages. */ + public class Connected internal constructor( + internal val receiveJob: Job, + internal val sendJob: Job, + ) : ConnectionState { + /** + * Cancel and await the active connection's coroutines. + * When called from within the receive loop (e.g. SubscriptionError + * with null requestId triggers disconnect()), [callerJob] matches + * [receiveJob] and both joins are skipped to avoid deadlock. + */ + internal suspend fun shutdown(callerJob: Job?) { + receiveJob.cancel() + sendJob.cancel() + if (callerJob != receiveJob) { + receiveJob.join() + sendJob.join() + } + } + } + + /** The connection has been closed and cannot be reused. */ + public data object Closed : ConnectionState +} + +/** + * Main entry point for connecting to a SpacetimeDB module. 
+ * + * Handles: + * - WebSocket connection lifecycle + * - Message send/receive loop + * - Client cache management + * - Subscription tracking + * - Reducer call tracking + */ +public open class DbConnection internal constructor( + private val transport: Transport, + private val scope: CoroutineScope, + onConnectCallbacks: List<(DbConnectionView, Identity, String) -> Unit>, + onDisconnectCallbacks: List<(DbConnectionView, Throwable?) -> Unit>, + onConnectErrorCallbacks: List<(DbConnectionView, Throwable) -> Unit>, + private val clientConnectionId: ConnectionId, + /** Performance statistics for this connection (request latencies, message counts, etc.). */ + public val stats: Stats, + internal val moduleDescriptor: ModuleDescriptor?, + private val callbackDispatcher: CoroutineDispatcher?, +) : DbConnectionView { + /** Local cache of subscribed table rows, kept in sync with the server. */ + @InternalSpacetimeApi + public val clientCache: ClientCache = ClientCache() + + private val _moduleTables = atomic(null) + public override var moduleTables: ModuleTables? + get() = _moduleTables.value + internal set(value) { _moduleTables.value = value } + + private val _moduleReducers = atomic(null) + public override var moduleReducers: ModuleReducers? + get() = _moduleReducers.value + internal set(value) { _moduleReducers.value = value } + + private val _moduleProcedures = atomic(null) + public override var moduleProcedures: ModuleProcedures? + get() = _moduleProcedures.value + internal set(value) { _moduleProcedures.value = value } + + private val _identity = atomic(null) + public override val identity: Identity? + get() = _identity.value + + private val _connectionId = atomic(null) + public override val connectionId: ConnectionId? + get() = _connectionId.value + + private val _token = atomic(null) + /** Authentication token assigned by the server, or `null` before connection. */ + public var token: String? 
+ get() = _token.value + private set(value) { _token.value = value } + + private val _state = atomic(ConnectionState.Disconnected) + public override val isActive: Boolean get() = _state.value is ConnectionState.Connected + + private val sendChannel = Channel(Channel.UNLIMITED) + private val _nextQuerySetId = atomic(0) + private val subscriptions = atomic(persistentHashMapOf()) + private val reducerCallbacks = + atomic(persistentHashMapOf) -> Unit>()) + private val reducerCallInfo = atomic(persistentHashMapOf()) + private val procedureCallbacks = + atomic(persistentHashMapOf Unit>()) + private val oneOffQueryCallbacks = + atomic(persistentHashMapOf) -> Unit>()) + private val querySetIdToRequestId = atomic(persistentHashMapOf()) + private val _eventId = atomic(0L) + private val _onConnectCallbacks = onConnectCallbacks.toList() + private val _onDisconnectCallbacks = atomic(onDisconnectCallbacks.toPersistentList()) + private val _onConnectErrorCallbacks = atomic(onConnectErrorCallbacks.toPersistentList()) + + // --- Connection callbacks --- + + public override fun onDisconnect(cb: (DbConnectionView, Throwable?) -> Unit) { + _onDisconnectCallbacks.update { it.add(cb) } + } + + public override fun removeOnDisconnect(cb: (DbConnectionView, Throwable?) -> Unit) { + _onDisconnectCallbacks.update { it.remove(cb) } + } + + public override fun onConnectError(cb: (DbConnectionView, Throwable) -> Unit) { + _onConnectErrorCallbacks.update { it.add(cb) } + } + + public override fun removeOnConnectError(cb: (DbConnectionView, Throwable) -> Unit) { + _onConnectErrorCallbacks.update { it.remove(cb) } + } + + private fun nextEventId(): String { + val id = _eventId.incrementAndGet() + return "${connectionId?.toHexString() ?: clientConnectionId.toHexString()}:$id" + } + + /** + * Run a user callback, optionally dispatching to the configured [callbackDispatcher]. + * When no dispatcher is set, callbacks run on the current (receive-loop) thread. 
+ * Catches and logs exceptions from user code without crashing the receive loop. + */ + internal suspend fun runUserCallback(block: () -> Unit) { + try { + val dispatcher = callbackDispatcher + if (dispatcher != null) { + withContext(dispatcher) { block() } + } else { + block() + } + } catch (e: Exception) { + currentCoroutineContext().ensureActive() + Logger.exception(e) + } + } + + /** + * Connect to SpacetimeDB and start the message receive loop. + * Called internally by [Builder.build]. Not intended for direct use. + * + * If the transport fails to connect, [onConnectError] callbacks are fired + * and the connection transitions to [ConnectionState.Closed]. + * No exception is thrown — errors are reported via callbacks. + */ + internal suspend fun connect() { + val disconnected = _state.value as? ConnectionState.Disconnected + ?: error( + if (_state.value is ConnectionState.Closed) + "Connection is closed. Create a new DbConnection to reconnect." + else + "connect() called in invalid state: ${_state.value}" + ) + check(_state.compareAndSet(disconnected, ConnectionState.Connecting)) { + "connect() called in invalid state: ${_state.value}" + } + Logger.info { "Connecting to SpacetimeDB..." 
} + try { + transport.connect() + } catch (e: Exception) { + _state.value = ConnectionState.Closed + scope.cancel() + for (cb in _onConnectErrorCallbacks.value) runUserCallback { cb(this, e) } + return + } + + // Start sender coroutine — drains any buffered messages in FIFO order + val sendJob = scope.launch { + for (msg in sendChannel) { + transport.send(msg) + } + } + + // Start receive loop + val receiveJob = scope.launch { + try { + transport.incoming().collect { message -> + val applyStart = kotlin.time.TimeSource.Monotonic.markNow() + processMessage(message) + stats.applyMessageTracker.insertSample(applyStart.elapsedNow()) + } + // Normal completion — server closed the connection + _state.value = ConnectionState.Closed + sendChannel.close() + failPendingOperations() + val cbs = _onDisconnectCallbacks.getAndSet(persistentListOf()) + for (cb in cbs) runUserCallback { cb(this@DbConnection, null) } + clientCache.clear() + } catch (e: Exception) { + currentCoroutineContext().ensureActive() + Logger.error { "Connection error: ${e.message}" } + _state.value = ConnectionState.Closed + sendChannel.close() + failPendingOperations() + val cbs = _onDisconnectCallbacks.getAndSet(persistentListOf()) + for (cb in cbs) runUserCallback { cb(this@DbConnection, e) } + clientCache.clear() + } finally { + withContext(NonCancellable) { + sendChannel.close() + try { transport.disconnect() } catch (_: Exception) {} + } + } + } + + _state.compareAndSet(ConnectionState.Connecting, ConnectionState.Connected(receiveJob, sendJob)) + } + + /** + * Disconnect from SpacetimeDB and release all resources. + * The connection cannot be reused — create a new [DbConnection] to reconnect. + * + * @param reason if non-null, passed to onDisconnect callbacks to distinguish + * error-driven disconnects from graceful ones. + */ + public override suspend fun disconnect(reason: Throwable?) 
{ + val prev = _state.getAndSet(ConnectionState.Closed) + if (prev is ConnectionState.Disconnected || prev is ConnectionState.Closed) return + Logger.info { "Disconnecting from SpacetimeDB" } + // Close the send channel FIRST so concurrent callReducer/oneOffQuery/etc. + // calls fail immediately instead of enqueuing messages that will never + // get responses. This eliminates the TOCTOU window between state=CLOSED + // and the channel close that previously lived in the receive job's finally block. + // (Double-close is safe for Channels — it's a no-op.) + sendChannel.close() + if (prev is ConnectionState.Connected) { + prev.shutdown(currentCoroutineContext()[Job]) + } + failPendingOperations() + val cbs = _onDisconnectCallbacks.getAndSet(persistentListOf()) + for (cb in cbs) runUserCallback { cb(this@DbConnection, reason) } + clientCache.clear() + scope.cancel() + } + + /** + * Fail all in-flight operations on disconnect. + * Clears callback maps so captured lambdas can be GC'd, and marks all + * subscription handles as ENDED so callers don't try to use stale handles. 
+ */ + private suspend fun failPendingOperations() { + val pendingReducers = reducerCallbacks.getAndSet(persistentHashMapOf()) + reducerCallInfo.getAndSet(persistentHashMapOf()) + if (pendingReducers.isNotEmpty()) { + Logger.warn { "Discarding ${pendingReducers.size} pending reducer callback(s) due to disconnect" } + } + + val pendingProcedures = procedureCallbacks.getAndSet(persistentHashMapOf()) + if (pendingProcedures.isNotEmpty()) { + Logger.warn { "Failing ${pendingProcedures.size} pending procedure callback(s) due to disconnect" } + val errorMsg = "Connection closed before procedure result was received" + for ((requestId, cb) in pendingProcedures) { + val procedureEvent = ProcedureEvent( + timestamp = Timestamp.UNIX_EPOCH, + status = ProcedureStatus.InternalError(errorMsg), + callerIdentity = identity ?: Identity.zero(), + callerConnectionId = connectionId, + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + requestId = requestId, + ) + val ctx = EventContext.Procedure( + id = nextEventId(), + connection = this, + event = procedureEvent, + ) + val resultMsg = ServerMessage.ProcedureResultMsg( + status = ProcedureStatus.InternalError(errorMsg), + timestamp = Timestamp.UNIX_EPOCH, + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + requestId = requestId, + ) + runUserCallback { cb.invoke(ctx, resultMsg) } + } + } + + val pendingQueries = oneOffQueryCallbacks.getAndSet(persistentHashMapOf()) + if (pendingQueries.isNotEmpty()) { + Logger.warn { "Failing ${pendingQueries.size} pending one-off query callback(s) due to disconnect" } + val errorResult: SdkResult = SdkResult.Failure(QueryError.Disconnected) + for ((_, cb) in pendingQueries) { + runUserCallback { cb.invoke(errorResult) } + } + } + + querySetIdToRequestId.getAndSet(persistentHashMapOf()) + + val pendingSubs = subscriptions.getAndSet(persistentHashMapOf()) + for ((_, handle) in pendingSubs) { + handle.markEnded() + } + } + + // --- Subscription Builder --- + + public override fun 
subscriptionBuilder(): SubscriptionBuilder = SubscriptionBuilder(this) + + + // --- Subscriptions --- + + /** + * Subscribe to a set of SQL queries. + * Returns a SubscriptionHandle to track the subscription lifecycle. + */ + public override fun subscribe( + queries: List, + onApplied: List<(EventContext.SubscribeApplied) -> Unit>, + onError: List<(EventContext.Error, SubscriptionError) -> Unit>, + ): SubscriptionHandle { + val requestId = stats.subscriptionRequestTracker.startTrackingRequest() + val querySetId = QuerySetId(_nextQuerySetId.incrementAndGet().toUInt()) + val handle = SubscriptionHandle( + querySetId, + queries, + connection = this, + onAppliedCallbacks = onApplied, + onErrorCallbacks = onError + ) + subscriptions.update { it.put(querySetId.id, handle) } + querySetIdToRequestId.update { it.put(querySetId.id, requestId) } + + val message = ClientMessage.Subscribe( + requestId = requestId, + querySetId = querySetId, + queryStrings = queries, + ) + Logger.debug { "Subscribing with ${queries.size} queries (requestId=$requestId)" } + if (!sendMessage(message)) { + subscriptions.update { it.remove(querySetId.id) } + querySetIdToRequestId.update { it.remove(querySetId.id) } + stats.subscriptionRequestTracker.finishTrackingRequest(requestId) + handle.markEnded() + } + return handle + } + + public override fun subscribe(vararg queries: String): SubscriptionHandle = + subscribe(queries.toList()) + + internal fun unsubscribe(handle: SubscriptionHandle, flags: UnsubscribeFlags) { + val requestId = stats.subscriptionRequestTracker.startTrackingRequest() + val message = ClientMessage.Unsubscribe( + requestId = requestId, + querySetId = handle.querySetId, + flags = flags, + ) + if (!sendMessage(message)) { + stats.subscriptionRequestTracker.finishTrackingRequest(requestId) + } + } + + // --- Reducers --- + + /** + * Call a reducer on the server. + * The encodedArgs should be BSATN-encoded reducer arguments. 
+ * The typedArgs is the typed args object stored for the event context. + */ + @InternalSpacetimeApi + public fun callReducer( + reducerName: String, + encodedArgs: ByteArray, + typedArgs: A, + callback: ((EventContext.Reducer) -> Unit)? = null, + flags: UByte = 0u, + ): UInt { + val requestId = stats.reducerRequestTracker.startTrackingRequest(reducerName) + if (callback != null) { + @Suppress("UNCHECKED_CAST") + reducerCallbacks.update { + it.put( + requestId, + callback as (EventContext.Reducer<*>) -> Unit + ) + } + } + reducerCallInfo.update { it.put(requestId, ReducerCallInfo(reducerName, typedArgs as Any)) } + val message = ClientMessage.CallReducer( + requestId = requestId, + flags = flags, + reducer = reducerName, + args = encodedArgs, + ) + Logger.debug { "Calling reducer '$reducerName' (requestId=$requestId)" } + if (!sendMessage(message)) { + reducerCallbacks.update { it.remove(requestId) } + reducerCallInfo.update { it.remove(requestId) } + stats.reducerRequestTracker.finishTrackingRequest(requestId) + } + return requestId + } + + // --- Procedures --- + + /** + * Call a procedure on the server. + * The args should be BSATN-encoded procedure arguments. + */ + @InternalSpacetimeApi + public fun callProcedure( + procedureName: String, + args: ByteArray, + callback: ((EventContext.Procedure, ServerMessage.ProcedureResultMsg) -> Unit)? 
= null, + flags: UByte = 0u, + ): UInt { + val requestId = stats.procedureRequestTracker.startTrackingRequest(procedureName) + if (callback != null) { + procedureCallbacks.update { it.put(requestId, callback) } + } + val message = ClientMessage.CallProcedure( + requestId = requestId, + flags = flags, + procedure = procedureName, + args = args, + ) + Logger.debug { "Calling procedure '$procedureName' (requestId=$requestId)" } + if (!sendMessage(message)) { + procedureCallbacks.update { it.remove(requestId) } + stats.procedureRequestTracker.finishTrackingRequest(requestId) + } + return requestId + } + + // --- One-Off Queries --- + + /** + * Execute a one-off SQL query against the database. + * The result callback receives the query result or error. + */ + public override fun oneOffQuery( + queryString: String, + callback: (SdkResult) -> Unit, + ): UInt { + val requestId = stats.oneOffRequestTracker.startTrackingRequest() + oneOffQueryCallbacks.update { it.put(requestId, callback) } + val message = ClientMessage.OneOffQuery( + requestId = requestId, + queryString = queryString, + ) + Logger.debug { "Executing one-off query (requestId=$requestId)" } + if (!sendMessage(message)) { + oneOffQueryCallbacks.update { it.remove(requestId) } + stats.oneOffRequestTracker.finishTrackingRequest(requestId) + } + return requestId + } + + /** + * Execute a one-off SQL query against the database, suspending until the result is available. + * + * @param timeout maximum time to wait for a response. Defaults to [Duration.INFINITE]. + * Throws [kotlinx.coroutines.TimeoutCancellationException] if exceeded. 
+ */ + public override suspend fun oneOffQuery( + queryString: String, + timeout: Duration, + ): SdkResult { + suspend fun await(): SdkResult = + suspendCancellableCoroutine { cont -> + val requestId = oneOffQuery(queryString) { result -> + cont.resume(result) + } + cont.invokeOnCancellation { + oneOffQueryCallbacks.update { it.remove(requestId) } + } + } + return if (timeout.isInfinite()) await() else withTimeout(timeout) { await() } + } + + // --- Internal --- + + private fun sendMessage(message: ClientMessage): Boolean { + val result = sendChannel.trySend(message) + if (result.isFailure) { + Logger.warn { "Cannot send message: connection is not active" } + return false + } + return true + } + + private suspend fun processMessage(message: ServerMessage) { + when (message) { + is ServerMessage.InitialConnection -> { + // Validate identity consistency + val currentIdentity = identity + if (currentIdentity != null && currentIdentity != message.identity) { + val error = IllegalStateException( + "Server returned unexpected identity: ${message.identity}, expected: $currentIdentity" + ) + for (cb in _onConnectErrorCallbacks.value) runUserCallback { cb(this, error) } + // Throw so the receive loop's catch block transitions to CLOSED + // and cleans up resources. Without this, the connection stays in + // CONNECTED state with no identity — an inconsistent half-initialized state. + throw error + } + + _identity.value = message.identity + _connectionId.value = message.connectionId + if (token == null && message.token.isNotEmpty()) { + token = message.token + } + Logger.info { "Connected with identity=${message.identity}" } + for (cb in _onConnectCallbacks) runUserCallback { cb(this, message.identity, message.token) } + } + + is ServerMessage.SubscribeApplied -> { + val handle = subscriptions.value[message.querySetId.id] ?: return + val ctx = EventContext.SubscribeApplied(id = nextEventId(), connection = this) + var subRequestId: UInt? 
= null + querySetIdToRequestId.getAndUpdate { map -> + subRequestId = map[message.querySetId.id] + map.remove(message.querySetId.id) + } + subRequestId?.let { stats.subscriptionRequestTracker.finishTrackingRequest(it) } + + // Inserts only — no pre-apply phase needed + val callbacks = mutableListOf() + for (tableRows in message.rows.tables) { + val table = clientCache.getUntypedTable(tableRows.table) ?: continue + callbacks.addAll(table.applyInserts(ctx, tableRows.rows)) + } + + for (cb in callbacks) runUserCallback { cb.invoke() } + handle.handleApplied(ctx) + } + + is ServerMessage.UnsubscribeApplied -> { + val handle = subscriptions.value[message.querySetId.id] ?: return + val ctx = EventContext.UnsubscribeApplied(id = nextEventId(), connection = this) + + val callbacks = mutableListOf() + if (message.rows != null) { + // Parse: decode all rows once + val parsed = message.rows.tables.mapNotNull { tableRows -> + val table = clientCache.getUntypedTable(tableRows.table) ?: return@mapNotNull null + table to table.parseDeletes(tableRows.rows) + } + // Phase 1: PreApply ALL tables (fire onBeforeDelete before mutations) + for ((table, data) in parsed) { + table.preApplyDeletes(ctx, data) + } + // Phase 2: Apply ALL tables (mutate + collect post-callbacks) + for ((table, data) in parsed) { + callbacks.addAll(table.applyDeletes(ctx, data)) + } + } + + subscriptions.update { it.remove(message.querySetId.id) } + handle.handleEnd(ctx) + // Phase 3: Fire post-mutation callbacks + for (cb in callbacks) runUserCallback { cb.invoke() } + } + + is ServerMessage.SubscriptionError -> { + val handle = subscriptions.value[message.querySetId.id] ?: run { + Logger.warn { "Received SubscriptionError for unknown querySetId=${message.querySetId.id}" } + return + } + val subError = SubscriptionError.ServerError(message.error) + val ctx = EventContext.Error(id = nextEventId(), connection = this, error = Exception(message.error)) + Logger.error { "Subscription error: ${message.error}" } + 
var subRequestId: UInt? = null + querySetIdToRequestId.getAndUpdate { map -> + subRequestId = map[message.querySetId.id] + map.remove(message.querySetId.id) + } + subRequestId?.let { stats.subscriptionRequestTracker.finishTrackingRequest(it) } + + if (message.requestId == null) { + handle.handleError(ctx, subError) + disconnect(Exception(message.error)) + return + } + + handle.handleError(ctx, subError) + subscriptions.update { it.remove(message.querySetId.id) } + } + + is ServerMessage.TransactionUpdateMsg -> { + val ctx = EventContext.Transaction(id = nextEventId(), connection = this) + val callbacks = applyTransactionUpdate(ctx, message.update) + for (cb in callbacks) runUserCallback { cb.invoke() } + } + + is ServerMessage.ReducerResultMsg -> { + val result = message.result + var info: ReducerCallInfo? = null + reducerCallInfo.getAndUpdate { map -> + info = map[message.requestId] + map.remove(message.requestId) + } + stats.reducerRequestTracker.finishTrackingRequest(message.requestId) + val callerIdentity = identity ?: run { + Logger.error { "Received ReducerResultMsg before identity was set" } + reducerCallbacks.update { it.remove(message.requestId) } + return + } + val callerConnId = connectionId + val capturedInfo = info + + when (result) { + is ReducerOutcome.Ok -> { + val ctx = if (capturedInfo != null) { + EventContext.Reducer( + id = nextEventId(), + connection = this, + timestamp = message.timestamp, + reducerName = capturedInfo.name, + args = capturedInfo.typedArgs, + status = Status.Committed, + callerIdentity = callerIdentity, + callerConnectionId = callerConnId, + ) + } else { + EventContext.UnknownTransaction(id = nextEventId(), connection = this) + } + val callbacks = applyTransactionUpdate(ctx, result.transactionUpdate) + for (cb in callbacks) runUserCallback { cb.invoke() } + + if (ctx is EventContext.Reducer<*>) { + fireReducerCallbacks(message.requestId, ctx) + } + } + + is ReducerOutcome.OkEmpty -> { + if (capturedInfo != null) { + val ctx = 
EventContext.Reducer( + id = nextEventId(), + connection = this, + timestamp = message.timestamp, + reducerName = capturedInfo.name, + args = capturedInfo.typedArgs, + status = Status.Committed, + callerIdentity = callerIdentity, + callerConnectionId = callerConnId, + ) + fireReducerCallbacks(message.requestId, ctx) + } + } + + is ReducerOutcome.Err -> { + val errorMsg = decodeReducerError(result.error) + Logger.warn { "Reducer '${capturedInfo?.name}' failed: $errorMsg" } + if (capturedInfo != null) { + val ctx = EventContext.Reducer( + id = nextEventId(), + connection = this, + timestamp = message.timestamp, + reducerName = capturedInfo.name, + args = capturedInfo.typedArgs, + status = Status.Failed(errorMsg), + callerIdentity = callerIdentity, + callerConnectionId = callerConnId, + ) + fireReducerCallbacks(message.requestId, ctx) + } + } + + is ReducerOutcome.InternalError -> { + Logger.error { "Reducer '${capturedInfo?.name}' internal error: ${result.message}" } + if (capturedInfo != null) { + val ctx = EventContext.Reducer( + id = nextEventId(), + connection = this, + timestamp = message.timestamp, + reducerName = capturedInfo.name, + args = capturedInfo.typedArgs, + status = Status.Failed(result.message), + callerIdentity = callerIdentity, + callerConnectionId = callerConnId, + ) + fireReducerCallbacks(message.requestId, ctx) + } + } + } + } + + is ServerMessage.ProcedureResultMsg -> { + stats.procedureRequestTracker.finishTrackingRequest(message.requestId) + var cb: ((EventContext.Procedure, ServerMessage.ProcedureResultMsg) -> Unit)? 
= null + procedureCallbacks.getAndUpdate { map -> + cb = map[message.requestId] + map.remove(message.requestId) + } + val procIdentity = identity ?: run { + Logger.error { "Received ProcedureResultMsg before identity was set" } + return + } + val procConnId = connectionId + cb?.let { + val procedureEvent = ProcedureEvent( + timestamp = message.timestamp, + status = message.status, + callerIdentity = procIdentity, + callerConnectionId = procConnId, + totalHostExecutionDuration = message.totalHostExecutionDuration, + requestId = message.requestId, + ) + val ctx = EventContext.Procedure( + id = nextEventId(), + connection = this, + event = procedureEvent + ) + runUserCallback { it.invoke(ctx, message) } + } + } + + is ServerMessage.OneOffQueryResult -> { + stats.oneOffRequestTracker.finishTrackingRequest(message.requestId) + var cb: ((SdkResult) -> Unit)? = null + oneOffQueryCallbacks.getAndUpdate { map -> + cb = map[message.requestId] + map.remove(message.requestId) + } + cb?.let { callback -> + val sdkResult: SdkResult = when (val r = message.result) { + is QueryResult.Ok -> SdkResult.Success(OneOffQueryData(r.rows.tables.size)) + is QueryResult.Err -> SdkResult.Failure(QueryError.ServerError(r.error)) + } + runUserCallback { callback.invoke(sdkResult) } + } + } + } + } + + private suspend fun fireReducerCallbacks(requestId: UInt, ctx: EventContext.Reducer<*>) { + var cb: ((EventContext.Reducer<*>) -> Unit)? 
= null + reducerCallbacks.getAndUpdate { map -> + cb = map[requestId] + map.remove(requestId) + } + cb?.let { runUserCallback { it.invoke(ctx) } } + moduleDescriptor?.let { runUserCallback { it.handleReducerEvent(this, ctx) } } + } + + private fun applyTransactionUpdate( + ctx: EventContext, + update: TransactionUpdate, + ): List { + // Parse: decode all rows once + val allUpdates = mutableListOf, ParsedTableData>>() + for (querySetUpdate in update.querySets) { + for (tableUpdate in querySetUpdate.tables) { + val table = clientCache.getUntypedTable(tableUpdate.tableName) ?: continue + for (rows in tableUpdate.rows) { + allUpdates.add(table to table.parseUpdate(rows)) + } + } + } + + // Phase 1: PreApply ALL tables (fire onBeforeDelete before any mutations) + for ((table, parsed) in allUpdates) { + table.preApplyUpdate(ctx, parsed) + } + + // Phase 2: Apply ALL tables (mutate + collect post-callbacks) + val allCallbacks = mutableListOf() + for ((table, parsed) in allUpdates) { + allCallbacks.addAll(table.applyUpdate(ctx, parsed)) + } + + return allCallbacks + } + + // --- Builder --- + + /** Fluent builder for configuring and creating a [DbConnection]. */ + public class Builder { + private var uri: String? = null + private var nameOrAddress: String? = null + private var authToken: String? = null + private var compression: CompressionMode = defaultCompressionMode + private var lightMode: Boolean = false + private var confirmedReads: Boolean? = null + private val onConnectCallbacks = mutableListOf<(DbConnectionView, Identity, String) -> Unit>() + private val onDisconnectCallbacks = mutableListOf<(DbConnectionView, Throwable?) -> Unit>() + private val onConnectErrorCallbacks = mutableListOf<(DbConnectionView, Throwable) -> Unit>() + private var module: ModuleDescriptor? = null + private var callbackDispatcher: CoroutineDispatcher? = null + private var httpClient: HttpClient? = null + + /** + * Provide the [HttpClient] for the WebSocket transport. 
+ * Must have the Ktor WebSockets plugin installed. + */ + public fun withHttpClient(client: HttpClient): Builder = apply { httpClient = client } + + /** Sets the SpacetimeDB server URI (e.g. `http://localhost:3000`). */ + public fun withUri(uri: String): Builder = apply { this.uri = uri } + /** Sets the database name or address to connect to. */ + public fun withDatabaseName(nameOrAddress: String): Builder = + apply { this.nameOrAddress = nameOrAddress } + + /** Sets the authentication token, or `null` for anonymous connections. */ + public fun withToken(token: String?): Builder = apply { authToken = token } + /** Sets the compression mode for the WebSocket connection. */ + public fun withCompression(compression: CompressionMode): Builder = + apply { this.compression = compression } + + /** Enables or disables light mode (reduced initial data transfer). */ + public fun withLightMode(lightMode: Boolean): Builder = apply { this.lightMode = lightMode } + /** Enables or disables confirmed reads from the server. */ + public fun withConfirmedReads(confirmed: Boolean): Builder = apply { confirmedReads = confirmed } + + /** + * Set a [CoroutineDispatcher] for user callbacks (onInsert, onDelete, onUpdate, + * onConnect, reducer callbacks, etc.). When set, all user callbacks are dispatched + * via [withContext] to this dispatcher. When not set (the default), callbacks run + * on the receive-loop thread ([kotlinx.coroutines.Dispatchers.Default]). + * + * Android example: `withCallbackDispatcher(Dispatchers.Main)` + */ + public fun withCallbackDispatcher(dispatcher: CoroutineDispatcher): Builder = + apply { this.callbackDispatcher = dispatcher } + + /** + * Register the generated module bindings. + * The generated `withModuleBindings()` extension calls this automatically. + */ + @InternalSpacetimeApi + public fun withModule(descriptor: ModuleDescriptor): Builder = apply { module = descriptor } + + /** Registers a callback invoked when the connection is established. 
*/ + public fun onConnect(cb: (DbConnectionView, Identity, String) -> Unit): Builder = + apply { onConnectCallbacks.add(cb) } + + /** Registers a callback invoked when the connection is closed. */ + public fun onDisconnect(cb: (DbConnectionView, Throwable?) -> Unit): Builder = + apply { onDisconnectCallbacks.add(cb) } + + /** Registers a callback invoked when a connection attempt fails. */ + public fun onConnectError(cb: (DbConnectionView, Throwable) -> Unit): Builder = + apply { onConnectErrorCallbacks.add(cb) } + + /** Builds and connects the [DbConnection]. Suspends until the WebSocket handshake completes. */ + public suspend fun build(): DbConnection { + module?.let { ensureMinimumVersion(it.cliVersion) } + require(compression in availableCompressionModes) { + "Compression mode $compression is not supported on this platform. " + + "Available modes: $availableCompressionModes" + } + val resolvedUri = requireNotNull(uri) { "URI is required" } + val resolvedModule = requireNotNull(nameOrAddress) { "Module name is required" } + val resolvedClient = requireNotNull(httpClient) { "HttpClient is required. Call withHttpClient() on the builder." 
} + val clientConnectionId = ConnectionId.random() + val stats = Stats() + + val transport = SpacetimeTransport( + client = resolvedClient, + baseUrl = resolvedUri, + nameOrAddress = resolvedModule, + connectionId = clientConnectionId, + authToken = authToken, + compression = compression, + lightMode = lightMode, + confirmedReads = confirmedReads, + ) + + val scope = CoroutineScope(SupervisorJob()) + + val conn = DbConnection( + transport = transport, + scope = scope, + onConnectCallbacks = onConnectCallbacks, + onDisconnectCallbacks = onDisconnectCallbacks, + onConnectErrorCallbacks = onConnectErrorCallbacks, + clientConnectionId = clientConnectionId, + stats = stats, + moduleDescriptor = module, + callbackDispatcher = callbackDispatcher, + ) + + module?.let { + it.registerTables(conn.clientCache) + val accessors = it.createAccessors(conn) + conn.moduleTables = accessors.tables + conn.moduleReducers = accessors.reducers + conn.moduleProcedures = accessors.procedures + } + conn.connect() + + return conn + } + + } +} + +/** + * Executes [block] with this [DbConnection], then calls [disconnect] when done. + * Ensures cleanup even if [block] throws or the coroutine is cancelled. + */ +public suspend inline fun DbConnection.use(block: (DbConnection) -> R): R { + try { + return block(this) + } finally { + withContext(NonCancellable) { + disconnect() + } + } +} + +/** Marker interface for generated table accessors. */ +public interface ModuleTables + +/** Marker interface for generated reducer accessors. */ +public interface ModuleReducers + +/** Marker interface for generated procedure accessors. */ +public interface ModuleProcedures + +/** Accessor instances created by [ModuleDescriptor.createAccessors]. */ +@InternalSpacetimeApi +public data class ModuleAccessors( + public val tables: ModuleTables, + public val reducers: ModuleReducers, + public val procedures: ModuleProcedures, +) + +/** + * Describes a generated SpacetimeDB module's bindings. 
+ * Implemented by the generated code to register tables and dispatch reducer events. + */ +@InternalSpacetimeApi +public interface ModuleDescriptor { + public val cliVersion: String + /** Names of persistent (subscribable) tables. Event tables are excluded. */ + public val subscribableTableNames: List + public fun registerTables(cache: ClientCache) + public fun createAccessors(conn: DbConnection): ModuleAccessors + public fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) +} + +private val MINIMUM_CLI_VERSION = intArrayOf(2, 0, 0) + +private fun parseVersion(version: String): IntArray { + val parts = version.split("-")[0].split(".") + return intArrayOf( + parts.getOrNull(0)?.toIntOrNull() ?: 0, + parts.getOrNull(1)?.toIntOrNull() ?: 0, + parts.getOrNull(2)?.toIntOrNull() ?: 0, + ) +} + +private fun ensureMinimumVersion(cliVersion: String) { + val parsed = parseVersion(cliVersion) + for (i in 0..2) { + if (parsed[i] > MINIMUM_CLI_VERSION[i]) return + if (parsed[i] < MINIMUM_CLI_VERSION[i]) { + val min = MINIMUM_CLI_VERSION.joinToString(".") + throw IllegalStateException( + "Module bindings were generated with spacetimedb cli $cliVersion, " + + "but this SDK requires at least $min. " + + "Regenerate bindings with an updated CLI: spacetime generate" + ) + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Errors.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Errors.kt new file mode 100644 index 00000000000..ab0e1197c96 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Errors.kt @@ -0,0 +1,23 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** Errors from a one-off SQL query. */ +public sealed interface QueryError : SdkError { + /** The server rejected the query or returned an error. 
*/ + public data class ServerError(val message: String) : QueryError + /** The connection was closed before the query result was received. */ + public data object Disconnected : QueryError +} + +/** Errors from a procedure call. */ +public sealed interface ProcedureError : SdkError { + /** The server reported an internal error executing the procedure. */ + public data class InternalError(val message: String) : ProcedureError + /** The connection was closed before the procedure result was received. */ + public data object Disconnected : ProcedureError +} + +/** Errors from a subscription. */ +public sealed interface SubscriptionError : SdkError { + /** The server rejected the subscription query. */ + public data class ServerError(val message: String) : SubscriptionError +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/EventContext.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/EventContext.kt new file mode 100644 index 00000000000..fe94ef4cddc --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/EventContext.kt @@ -0,0 +1,160 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ProcedureStatus +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlin.time.Duration + +/** + * Reducer call status. + */ +public sealed interface Status { + /** The reducer committed its transaction successfully. */ + public data object Committed : Status + /** The reducer failed with the given error [message]. 
*/ + public data class Failed(val message: String) : Status +} + +/** + * Procedure event data for procedure-specific callbacks. + */ +public data class ProcedureEvent( + val timestamp: Timestamp, + val status: ProcedureStatus, + val callerIdentity: Identity, + val callerConnectionId: ConnectionId?, + val totalHostExecutionDuration: TimeDuration, + val requestId: UInt, +) + +/** + * Scoped view of [DbConnection] exposed to callback code via [EventContext]. + * Restricts access to the subset of operations appropriate inside event handlers. + * + * Generated code adds extension properties (`db`, `reducers`, `procedures`) + * on this interface for typed access to module bindings. + */ +public interface DbConnectionView { + /** The identity assigned by the server, or `null` before connection. */ + public val identity: Identity? + /** The connection ID assigned by the server, or `null` before connection. */ + public val connectionId: ConnectionId? + /** Whether the connection is currently active. */ + public val isActive: Boolean + /** Generated table accessors, or `null` if no module bindings were registered. */ + public val moduleTables: ModuleTables? + /** Generated reducer accessors, or `null` if no module bindings were registered. */ + public val moduleReducers: ModuleReducers? + /** Generated procedure accessors, or `null` if no module bindings were registered. */ + public val moduleProcedures: ModuleProcedures? + + /** Creates a new [SubscriptionBuilder] for configuring and subscribing to queries. */ + public fun subscriptionBuilder(): SubscriptionBuilder + /** Subscribes to the given SQL [queries] with optional callbacks. */ + public fun subscribe( + queries: List, + onApplied: List<(EventContext.SubscribeApplied) -> Unit> = emptyList(), + onError: List<(EventContext.Error, SubscriptionError) -> Unit> = emptyList(), + ): SubscriptionHandle + /** Subscribes to the given SQL [queries]. 
*/ + public fun subscribe(vararg queries: String): SubscriptionHandle + + /** Executes a one-off SQL query with a callback for the result. */ + public fun oneOffQuery( + queryString: String, + callback: (SdkResult) -> Unit, + ): UInt + /** Executes a one-off SQL query, suspending until the result is available. */ + public suspend fun oneOffQuery( + queryString: String, + timeout: Duration = Duration.INFINITE, + ): SdkResult + + /** Disconnects from SpacetimeDB, optionally providing a [reason]. */ + public suspend fun disconnect(reason: Throwable? = null) + + /** Registers a callback invoked when the connection is closed. */ + public fun onDisconnect(cb: (DbConnectionView, Throwable?) -> Unit) + /** Removes a previously registered disconnect callback. */ + public fun removeOnDisconnect(cb: (DbConnectionView, Throwable?) -> Unit) + /** Registers a callback invoked when a connection attempt fails. */ + public fun onConnectError(cb: (DbConnectionView, Throwable) -> Unit) + /** Removes a previously registered connect-error callback. */ + public fun removeOnConnectError(cb: (DbConnectionView, Throwable) -> Unit) +} + +/** + * Context passed to callbacks. Sealed interface with specialized subtypes + * so callbacks receive only the fields relevant to their event type. + * + * Subtypes are plain classes (not data classes) because [connection] is a + * mutable handle, not value data — it should not participate in equals/hashCode. + */ +public sealed interface EventContext { + /** Unique identifier for this event. */ + public val id: String + /** The connection that produced this event. */ + public val connection: DbConnectionView + + /** Fired when a subscription's initial rows have been applied to the client cache. */ + public class SubscribeApplied( + override val id: String, + override val connection: DbConnectionView, + ) : EventContext + + /** Fired when an unsubscription has been confirmed by the server. 
*/ + public class UnsubscribeApplied( + override val id: String, + override val connection: DbConnectionView, + ) : EventContext + + /** Fired when a server-side transaction update has been applied. */ + public class Transaction( + override val id: String, + override val connection: DbConnectionView, + ) : EventContext + + /** Fired when a reducer result is received, carrying the typed arguments and status. */ + public class Reducer( + override val id: String, + override val connection: DbConnection, + public val timestamp: Timestamp, + public val reducerName: String, + public val args: A, + public val status: Status, + public val callerIdentity: Identity, + public val callerConnectionId: ConnectionId?, + ) : EventContext + + /** Fired when a procedure result is received. */ + public class Procedure( + override val id: String, + override val connection: DbConnection, + public val event: ProcedureEvent, + ) : EventContext + + /** Fired when an error occurs, such as a subscription error. */ + public class Error( + override val id: String, + override val connection: DbConnection, + public val error: Throwable, + ) : EventContext + + /** + * A reducer result was received but no matching [ReducerCallInfo] was found. + * This is defensive — it can happen if the reducer was called from another client + * or if the call info was lost (e.g. reconnect). + */ + public class UnknownTransaction( + override val id: String, + override val connection: DbConnectionView, + ) : EventContext +} + +/** Test-only [EventContext] stub. Not part of the public API. 
*/ +internal class StubEventContext(override val id: String = "test") : EventContext { + override val connection: DbConnectionView + get() = error("StubEventContext.connection should not be accessed in unit tests") +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Index.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Index.kt new file mode 100644 index 00000000000..5e3a16d79d6 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Index.kt @@ -0,0 +1,89 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentSet +import kotlinx.collections.immutable.persistentHashMapOf +import kotlinx.collections.immutable.persistentHashSetOf + +/** + * A client-side unique index backed by an atomic persistent map. + * Provides O(1) lookup by the indexed column value. + * Thread-safe: reads return a consistent snapshot. + * + * Subscribes to the TableCache's internal insert/delete hooks + * to stay synchronized with the cache contents. + */ +public class UniqueIndex( + tableCache: TableCache, + private val keyExtractor: (Row) -> Col, +) { + private val _cache = atomic(persistentHashMapOf()) + + init { + // Register listeners before populating so rows inserted concurrently + // cause a CAS retry in the population update, picking them up via iter(). + tableCache.addInternalInsertListener { row -> + _cache.update { it.put(keyExtractor(row), row) } + } + tableCache.addInternalDeleteListener { row -> + _cache.update { it.remove(keyExtractor(row)) } + } + _cache.update { + val builder = it.builder() + for (row in tableCache.iter()) { + builder[keyExtractor(row)] = row + } + builder.build() + } + } + + /** Returns the row matching [value], or `null` if no match. 
*/ + public fun find(value: Col): Row? = _cache.value[value] +} + +/** + * A client-side non-unique index backed by an atomic persistent map of persistent sets. + * Provides lookup for all rows matching a given column value. + * Thread-safe: reads return a consistent snapshot. + * + * Uses [PersistentSet] (not List) so that add is idempotent — if the listener + * and the population loop both add the same row during init, no duplicate is produced. + * + * Subscribes to the TableCache's internal insert/delete hooks + * to stay synchronized with the cache contents. + */ +public class BTreeIndex( + tableCache: TableCache, + private val keyExtractor: (Row) -> Col, +) { + private val _cache = atomic(persistentHashMapOf>()) + + init { + tableCache.addInternalInsertListener { row -> + val key = keyExtractor(row) + _cache.update { current -> + current.put(key, (current[key] ?: persistentHashSetOf()).add(row)) + } + } + tableCache.addInternalDeleteListener { row -> + val key = keyExtractor(row) + _cache.update { current -> + val set = current[key] ?: return@update current + val updated = set.remove(row) + if (updated.isEmpty()) current.remove(key) else current.put(key, updated) + } + } + _cache.update { current -> + val builder = current.builder() + for (row in tableCache.iter()) { + val key = keyExtractor(row) + builder[key] = (builder[key] ?: persistentHashSetOf()).add(row) + } + builder.build() + } + } + + /** Returns all rows matching [value], or an empty set if none. 
*/ + public fun filter(value: Col): Set = _cache.value[value] ?: emptySet() +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Int128.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Int128.kt new file mode 100644 index 00000000000..83bbf1c9bdf --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Int128.kt @@ -0,0 +1,21 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import kotlin.jvm.JvmInline + +/** A signed 128-bit integer, backed by [BigInteger]. */ +@JvmInline +public value class Int128(public val value: BigInteger) : Comparable { + /** Encodes this value to BSATN. */ + public fun encode(writer: BsatnWriter): Unit = writer.writeI128(value) + override fun compareTo(other: Int128): Int = value.compareTo(other.value) + override fun toString(): String = value.toString() + + public companion object { + /** Decodes an [Int128] from BSATN. */ + public fun decode(reader: BsatnReader): Int128 = Int128(reader.readI128()) + /** A zero-valued [Int128]. 
*/ + public val ZERO: Int128 = Int128(BigInteger.ZERO) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Int256.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Int256.kt new file mode 100644 index 00000000000..04c5f8b6680 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Int256.kt @@ -0,0 +1,21 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import kotlin.jvm.JvmInline + +/** A signed 256-bit integer, backed by [BigInteger]. */ +@JvmInline +public value class Int256(public val value: BigInteger) : Comparable { + /** Encodes this value to BSATN. */ + public fun encode(writer: BsatnWriter): Unit = writer.writeI256(value) + override fun compareTo(other: Int256): Int = value.compareTo(other.value) + override fun toString(): String = value.toString() + + public companion object { + /** Decodes an [Int256] from BSATN. */ + public fun decode(reader: BsatnReader): Int256 = Int256(reader.readI256()) + /** A zero-valued [Int256]. 
*/ + public val ZERO: Int256 = Int256(BigInteger.ZERO) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/InternalApi.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/InternalApi.kt new file mode 100644 index 00000000000..f62496c01ec --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/InternalApi.kt @@ -0,0 +1,13 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * Marks declarations that are internal to the SpacetimeDB SDK and generated code. + * Using them from user code is unsupported and may break without notice. + */ +@RequiresOptIn( + message = "This is internal to the SpacetimeDB SDK and generated code. Do not use directly.", + level = RequiresOptIn.Level.ERROR, +) +@Retention(AnnotationRetention.BINARY) +@Target(AnnotationTarget.CLASS, AnnotationTarget.CONSTRUCTOR, AnnotationTarget.FUNCTION, AnnotationTarget.PROPERTY) +public annotation class InternalSpacetimeApi diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Logger.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Logger.kt new file mode 100644 index 00000000000..af229111e80 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Logger.kt @@ -0,0 +1,108 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlinx.atomicfu.atomic + +/** + * Log severity levels for the SpacetimeDB SDK. + */ +public enum class LogLevel { + /** Unrecoverable errors (exceptions with stack traces). */ + EXCEPTION, + /** Errors that may be recoverable. */ + ERROR, + /** Potentially harmful situations. */ + WARN, + /** Informational messages about connection lifecycle. 
*/ + INFO, + /** Detailed diagnostic information. */ + DEBUG, + /** Fine-grained tracing of internal operations. */ + TRACE; + + internal fun shouldLog(threshold: LogLevel): Boolean = this.ordinal <= threshold.ordinal +} + +/** + * Handler for log output. Implement to route logs to a custom destination. + */ +public fun interface LogHandler { + /** Emits a log message at the given [level]. */ + public fun log(level: LogLevel, message: String) +} + +private val SENSITIVE_KEYS = listOf("token", "authtoken", "auth_token", "password", "secret", "credential", "api_key", "apikey", "bearer") + +private val SENSITIVE_PATTERNS: List<Regex> by lazy { + SENSITIVE_KEYS.map { key -> + Regex("""($key\s*[=:]\s*)\S+""", RegexOption.IGNORE_CASE) + } +} + +/** + * Redact sensitive key-value pairs from a message string. + */ +private fun redactSensitive(message: String): String { + val lower = message.lowercase() + if (SENSITIVE_KEYS.none { it in lower }) return message + var result = message + for (pattern in SENSITIVE_PATTERNS) { + result = result.replace(pattern, "$1[REDACTED]") + } + return result +} + +/** + * Global logger for the SpacetimeDB SDK. + * Configurable level and handler with lazy message evaluation. + */ +public object Logger { + private val _level = atomic(LogLevel.INFO) + private val _handler = atomic(LogHandler { lvl, msg -> + println("[SpacetimeDB ${lvl.name}] $msg") + }) + + /** Minimum severity level; messages below this threshold are discarded. */ + public var level: LogLevel + get() = _level.value + set(value) { _level.value = value } + + /** The active log handler. Replace to route SDK logs to your logging framework. */ + public var handler: LogHandler + get() = _handler.value + set(value) { _handler.value = value } + + /** Logs a throwable's stack trace at EXCEPTION level. 
*/ + public fun exception(throwable: Throwable) { + if (LogLevel.EXCEPTION.shouldLog(level)) handler.log(LogLevel.EXCEPTION, redactSensitive(throwable.stackTraceToString())) + } + + /** Logs a lazily-evaluated message at EXCEPTION level. */ + public fun exception(message: () -> String) { + if (LogLevel.EXCEPTION.shouldLog(level)) handler.log(LogLevel.EXCEPTION, redactSensitive(message())) + } + + /** Logs a lazily-evaluated message at ERROR level. */ + public fun error(message: () -> String) { + if (LogLevel.ERROR.shouldLog(level)) handler.log(LogLevel.ERROR, redactSensitive(message())) + } + + /** Logs a lazily-evaluated message at WARN level. */ + public fun warn(message: () -> String) { + if (LogLevel.WARN.shouldLog(level)) handler.log(LogLevel.WARN, redactSensitive(message())) + } + + /** Logs a lazily-evaluated message at INFO level. */ + public fun info(message: () -> String) { + if (LogLevel.INFO.shouldLog(level)) handler.log(LogLevel.INFO, redactSensitive(message())) + } + + /** Logs a lazily-evaluated message at DEBUG level. */ + public fun debug(message: () -> String) { + if (LogLevel.DEBUG.shouldLog(level)) handler.log(LogLevel.DEBUG, redactSensitive(message())) + } + + /** Logs a lazily-evaluated message at TRACE level. */ + public fun trace(message: () -> String) { + if (LogLevel.TRACE.shouldLog(level)) handler.log(LogLevel.TRACE, redactSensitive(message())) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/OneOffQueryData.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/OneOffQueryData.kt new file mode 100644 index 00000000000..955050203f9 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/OneOffQueryData.kt @@ -0,0 +1,10 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** Success payload for a one-off SQL query result. 
*/ +public data class OneOffQueryData( + /** Number of tables that returned rows. */ + val tableCount: Int, +) + +/** Result type for one-off SQL queries. */ +public typealias OneOffQueryResult = SdkResult diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/RemoteTable.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/RemoteTable.kt new file mode 100644 index 00000000000..cb2d148ce59 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/RemoteTable.kt @@ -0,0 +1,62 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * Sealed hierarchy for generated table handles. + * Use `is RemotePersistentTable` / `is RemoteEventTable` to distinguish at runtime. + * + * - [RemotePersistentTable]: rows are stored in the client cache; supports + * count/all/iter, onDelete, and onBeforeDelete. + * - [RemoteEventTable]: rows are NOT stored; only onInsert fires per event. + */ +public sealed interface RemoteTable { + /** Registers a callback that fires when a row is inserted. */ + public fun onInsert(cb: (EventContext, Row) -> Unit) + + /** Removes a previously registered insert callback. */ + public fun removeOnInsert(cb: (EventContext, Row) -> Unit) +} + +/** + * A generated table handle backed by a persistent (stored) table. + * Provides read access to cached rows and callbacks for inserts, deletes, and before-delete. + */ +public interface RemotePersistentTable : RemoteTable { + /** Returns the number of rows currently in the client cache for this table. */ + public fun count(): Int + + /** Returns a snapshot list of all cached rows. */ + public fun all(): List + + /** Returns a lazy sequence over all cached rows. */ + public fun iter(): Sequence + + /** Registers a callback that fires after a row is deleted. 
*/ + public fun onDelete(cb: (EventContext, Row) -> Unit) + + /** Removes a previously registered delete callback. */ + public fun removeOnDelete(cb: (EventContext, Row) -> Unit) + + /** Registers a callback that fires before a row is deleted. */ + public fun onBeforeDelete(cb: (EventContext, Row) -> Unit) + + /** Removes a previously registered before-delete callback. */ + public fun removeOnBeforeDelete(cb: (EventContext, Row) -> Unit) +} + +/** + * A [RemotePersistentTable] whose rows have a primary key. + * Adds [onUpdate] / [removeOnUpdate] callbacks that fire when an existing row is replaced. + */ +public interface RemotePersistentTableWithPrimaryKey : RemotePersistentTable { + /** Registers a callback that fires when an existing row is replaced (old row, new row). */ + public fun onUpdate(cb: (EventContext, Row, Row) -> Unit) + + /** Removes a previously registered update callback. */ + public fun removeOnUpdate(cb: (EventContext, Row, Row) -> Unit) +} + +/** + * A generated table handle backed by an event (non-stored) table. + * Rows are not cached; only insert callbacks fire per event. + */ +public interface RemoteEventTable : RemoteTable diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SdkError.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SdkError.kt new file mode 100644 index 00000000000..c59d2315fa7 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SdkError.kt @@ -0,0 +1,8 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * Marker interface for typed SDK errors. + * All error types returned by SDK operations implement this interface, + * enabling exhaustive `when` blocks on [SdkResult.Failure]. 
+ */ +public interface SdkError diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SdkResult.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SdkResult.kt new file mode 100644 index 00000000000..f41fccc83fb --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SdkResult.kt @@ -0,0 +1,52 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * A discriminated union representing either a successful value or a typed error. + * Unlike `kotlin.Result`, the error type [E] is preserved at compile time, + * enabling exhaustive pattern matching on error variants. + */ +public sealed interface SdkResult<out T, out E> { + /** Successful outcome holding [data]. */ + public data class Success<out T>(val data: T) : SdkResult<T, Nothing> + /** Failed outcome holding a typed [error]. */ + public data class Failure<out E>(val error: E) : SdkResult<Nothing, E> +} + +/** Alias for operations that succeed with [Unit] or fail with [E]. */ +public typealias EmptySdkResult<E> = SdkResult<Unit, E> + +/** Runs [action] if this is [SdkResult.Success], returns `this` unchanged. */ +public inline fun <T, E> SdkResult<T, E>.onSuccess( + action: (T) -> Unit, +): SdkResult<T, E> { + if (this is SdkResult.Success) action(data) + return this +} + +/** Runs [action] if this is [SdkResult.Failure], returns `this` unchanged. */ +public inline fun <T, E> SdkResult<T, E>.onFailure( + action: (E) -> Unit, +): SdkResult<T, E> { + if (this is SdkResult.Failure) action(error) + return this +} + +/** Transforms the success value with [transform], preserving errors. */ +public inline fun <T, E, R> SdkResult<T, E>.map( + transform: (T) -> R, +): SdkResult<R, E> = when (this) { + is SdkResult.Success -> SdkResult.Success(transform(data)) + is SdkResult.Failure -> this +} + +/** Returns the success value, or `null` if this is a failure. */ +public fun <T, E> SdkResult<T, E>.getOrNull(): T? = + (this as? 
SdkResult.Success<T>)?.data + +/** Returns the error, or `null` if this is a success. */ +public fun <T, E> SdkResult<T, E>.errorOrNull(): E? = + (this as? SdkResult.Failure<E>)?.error + +/** Discards the success value, preserving only the success/failure status. */ +public fun <E> SdkResult<*, E>.asEmptyResult(): EmptySdkResult<E> = + map { } diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SpacetimeResult.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SpacetimeResult.kt new file mode 100644 index 00000000000..7dbb46948d2 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SpacetimeResult.kt @@ -0,0 +1,12 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * A sum type representing either a successful [Ok] value or an [Err] error. + * Corresponds to `Result` in the SpacetimeDB module schema. + */ +public sealed interface SpacetimeResult<out T, out E> { + /** Successful variant holding [value]. */ + public data class Ok<out T>(val value: T) : SpacetimeResult<T, Nothing> + /** Error variant holding [error]. */ + public data class Err<out E>(val error: E) : SpacetimeResult<Nothing, E> +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SqlFormat.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SqlFormat.kt new file mode 100644 index 00000000000..0909eaca4bd --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SqlFormat.kt @@ -0,0 +1,39 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * SQL formatting utilities for the typed query builder. + * Handles identifier quoting and literal escaping. 
+ */ +@InternalSpacetimeApi +public object SqlFormat { + /** + * Quote a SQL identifier with double quotes, escaping internal double quotes by doubling. + * Example: `tableName` → `"tableName"`, `bad"name` → `"bad""name"` + */ + public fun quoteIdent(ident: String): String = "\"${ident.replace("\"", "\"\"")}\"" + + /** + * Format a string value as a SQL string literal with single quotes. + * Internal single quotes are escaped by doubling. + * Example: `O'Brien` → `'O''Brien'` + */ + public fun formatStringLiteral(value: String): String = "'${value.replace("'", "''")}'" + + /** + * Format a hex string as a SQL hex literal. + * Strips optional `0x` prefix and hyphens, validates all characters are hex digits. + * Example: `01020304` → `0x01020304` + */ + public fun formatHexLiteral(hex: String): String { + var cleaned = hex + if (cleaned.startsWith("0x", ignoreCase = true)) { + cleaned = cleaned.substring(2) + } + cleaned = cleaned.replace("-", "") + require(cleaned.isNotEmpty()) { "Empty hex string: $hex" } + require(cleaned.all { it in '0'..'9' || it in 'a'..'f' || it in 'A'..'F' }) { + "Invalid hex string: $hex" + } + return "0x$cleaned" + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SqlLiteral.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SqlLiteral.kt new file mode 100644 index 00000000000..cd2563b6fa1 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SqlLiteral.kt @@ -0,0 +1,117 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.SpacetimeUuid +import kotlin.jvm.JvmInline + +/** + * A type-safe wrapper around a 
SQL literal string. + * The type parameter [T] tracks the Kotlin type at compile time + * to ensure column comparisons are type-safe. + */ +@JvmInline +public value class SqlLiteral<@Suppress("unused") T> @InternalSpacetimeApi constructor(@property:InternalSpacetimeApi public val sql: String) + +/** + * Factory for creating [SqlLiteral] values from Kotlin types. + * + * Each method converts a native Kotlin value into its SQL literal representation. + */ +public object SqlLit { + public fun string(value: String): SqlLiteral = + SqlLiteral(SqlFormat.formatStringLiteral(value)) + + public fun bool(value: Boolean): SqlLiteral = + SqlLiteral(if (value) "TRUE" else "FALSE") + + public fun byte(value: Byte): SqlLiteral = SqlLiteral(value.toString()) + public fun ubyte(value: UByte): SqlLiteral = SqlLiteral(value.toString()) + public fun short(value: Short): SqlLiteral = SqlLiteral(value.toString()) + public fun ushort(value: UShort): SqlLiteral = SqlLiteral(value.toString()) + public fun int(value: Int): SqlLiteral = SqlLiteral(value.toString()) + public fun uint(value: UInt): SqlLiteral = SqlLiteral(value.toString()) + public fun long(value: Long): SqlLiteral = SqlLiteral(value.toString()) + public fun ulong(value: ULong): SqlLiteral = SqlLiteral(value.toString()) + public fun float(value: Float): SqlLiteral { + require(value.isFinite()) { "SQL literals do not support NaN or Infinity" } + return SqlLiteral(value.toPlainDecimalString()) + } + + public fun double(value: Double): SqlLiteral { + require(value.isFinite()) { "SQL literals do not support NaN or Infinity" } + return SqlLiteral(value.toPlainDecimalString()) + } + + public fun int128(value: Int128): SqlLiteral = SqlLiteral(value.value.toString()) + public fun uint128(value: UInt128): SqlLiteral = SqlLiteral(value.value.toString()) + public fun int256(value: Int256): SqlLiteral = SqlLiteral(value.value.toString()) + public fun uint256(value: UInt256): SqlLiteral = SqlLiteral(value.value.toString()) + + public fun 
identity(value: Identity): SqlLiteral = + SqlLiteral(SqlFormat.formatHexLiteral(value.toHexString())) + + public fun connectionId(value: ConnectionId): SqlLiteral = + SqlLiteral(SqlFormat.formatHexLiteral(value.toHexString())) + + public fun uuid(value: SpacetimeUuid): SqlLiteral = + SqlLiteral(SqlFormat.formatHexLiteral(value.toHexString())) +} + +/** + * Formats a Float as a plain decimal string without scientific notation. + * Uses Float.toString() to preserve original float precision (avoids float→double expansion). + */ +private fun Float.toPlainDecimalString(): String { + val s = this.toString() + if ('E' !in s && 'e' !in s) return s + return expandScientificNotation(s) +} + +/** + * Formats a Double as a plain decimal string without scientific notation. + * Handles the E/e notation that Double.toString() may produce for very large or small values. + */ +private fun Double.toPlainDecimalString(): String { + val s = this.toString() + if ('E' !in s && 'e' !in s) return s + return expandScientificNotation(s) +} + +/** Expands a scientific notation string (e.g. "1.5E-7") to plain decimal (e.g. "0.00000015"). 
*/ +private fun expandScientificNotation(s: String): String { + val eIdx = s.indexOfFirst { it == 'E' || it == 'e' } + val mantissa = s.substring(0, eIdx) + val exponent = s.substring(eIdx + 1).toInt() + + val negative = mantissa.startsWith('-') + val absMantissa = if (negative) mantissa.substring(1) else mantissa + val dotIdx = absMantissa.indexOf('.') + val intPart = if (dotIdx >= 0) absMantissa.substring(0, dotIdx) else absMantissa + val fracPart = if (dotIdx >= 0) absMantissa.substring(dotIdx + 1) else "" + val allDigits = intPart + fracPart + val newDecimalPos = intPart.length + exponent + + val sb = StringBuilder() + if (negative) sb.append('-') + + when { + newDecimalPos <= 0 -> { + sb.append("0.") + repeat(-newDecimalPos) { sb.append('0') } + sb.append(allDigits) + } + newDecimalPos >= allDigits.length -> { + sb.append(allDigits) + repeat(newDecimalPos - allDigits.length) { sb.append('0') } + sb.append(".0") + } + else -> { + sb.append(allDigits, 0, newDecimalPos) + sb.append('.') + sb.append(allDigits, newDecimalPos, allDigits.length) + } + } + + return sb.toString() +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Stats.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Stats.kt new file mode 100644 index 00000000000..63e485a621a --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Stats.kt @@ -0,0 +1,182 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlinx.atomicfu.locks.SynchronizedObject +import kotlinx.atomicfu.locks.synchronized +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds +import kotlin.time.TimeMark +import kotlin.time.TimeSource + +/** A single latency sample with its associated metadata (e.g. reducer name). 
*/ +public data class DurationSample(val duration: Duration, val metadata: String) + +/** Min/max pair from a [NetworkRequestTracker] query. */ +public data class MinMaxResult(val min: DurationSample, val max: DurationSample) + +private class RequestEntry(val startTime: TimeMark, val metadata: String) + +/** + * Tracks request latencies over sliding time windows. + * Thread-safe — all reads and writes are synchronized. + * + * Use [minMaxTimes] to query min/max latency within a recent window, + * or [allTimeMinMax] for the lifetime extremes. + */ +public class NetworkRequestTracker internal constructor( + private val timeSource: TimeSource = TimeSource.Monotonic, +) : SynchronizedObject() { + internal constructor() : this(TimeSource.Monotonic) + + public companion object { + private const val MAX_TRACKERS = 16 + } + + private var allTimeMin: DurationSample? = null + private var allTimeMax: DurationSample? = null + + private val trackers = mutableMapOf() + private var totalSamples = 0 + private var nextRequestId = 0u + private val requests = mutableMapOf() + + /** All-time min/max latency, or `null` if no samples recorded yet. */ + public val allTimeMinMax: MinMaxResult? + get() = synchronized(this) { + val min = allTimeMin ?: return null + val max = allTimeMax ?: return null + MinMaxResult(min, max) + } + + /** Min/max latency within the last [lastSeconds] seconds, or `null` if no samples in that window. */ + public fun minMaxTimes(lastSeconds: Int): MinMaxResult? = synchronized(this) { + val tracker = trackers.getOrPut(lastSeconds) { + check(trackers.size < MAX_TRACKERS) { + "Cannot track more than $MAX_TRACKERS distinct window sizes" + } + WindowTracker(lastSeconds, timeSource) + } + tracker.getMinMax() + } + + /** Total number of latency samples recorded. */ + public val sampleCount: Int get() = synchronized(this) { totalSamples } + + /** Number of requests that have been started but not yet completed. 
*/ + public val requestsAwaitingResponse: Int get() = synchronized(this) { requests.size } + + internal fun startTrackingRequest(metadata: String = ""): UInt { + synchronized(this) { + val requestId = nextRequestId++ + requests[requestId] = RequestEntry( + startTime = timeSource.markNow(), + metadata = metadata, + ) + return requestId + } + } + + internal fun finishTrackingRequest(requestId: UInt, metadata: String? = null): Boolean { + synchronized(this) { + val entry = requests.remove(requestId) ?: return false + val duration = entry.startTime.elapsedNow() + val resolvedMetadata = metadata ?: entry.metadata + insertSampleLocked(duration, resolvedMetadata) + return true + } + } + + internal fun insertSample(duration: Duration, metadata: String = "") { + synchronized(this) { + insertSampleLocked(duration, metadata) + } + } + + private fun insertSampleLocked(duration: Duration, metadata: String) { + totalSamples++ + val sample = DurationSample(duration, metadata) + + val currentMin = allTimeMin + if (currentMin == null || duration < currentMin.duration) { + allTimeMin = sample + } + val currentMax = allTimeMax + if (currentMax == null || duration > currentMax.duration) { + allTimeMax = sample + } + + for (tracker in trackers.values) { + tracker.insertSample(duration, metadata) + } + } + + private class WindowTracker(windowSeconds: Int, private val timeSource: TimeSource) { + val window: Duration = windowSeconds.seconds + var lastReset: TimeMark = timeSource.markNow() + + var lastWindowMin: DurationSample? = null + private set + var lastWindowMax: DurationSample? = null + private set + var thisWindowMin: DurationSample? = null + private set + var thisWindowMax: DurationSample? 
= null + private set + + fun insertSample(duration: Duration, metadata: String) { + maybeRotate() + val sample = DurationSample(duration, metadata) + + val currentMin = thisWindowMin + if (currentMin == null || duration < currentMin.duration) { + thisWindowMin = sample + } + val currentMax = thisWindowMax + if (currentMax == null || duration > currentMax.duration) { + thisWindowMax = sample + } + } + + fun getMinMax(): MinMaxResult? { + maybeRotate() + val min = lastWindowMin ?: return null + val max = lastWindowMax ?: return null + return MinMaxResult(min, max) + } + + private fun maybeRotate() { + val elapsed = lastReset.elapsedNow() + if (elapsed >= window) { + if (elapsed >= window * 2) { + // More than one full window passed — no data in the immediately + // preceding window, so lastWindow should be empty. + lastWindowMin = null + lastWindowMax = null + } else { + lastWindowMin = thisWindowMin + lastWindowMax = thisWindowMax + } + thisWindowMin = null + thisWindowMax = null + lastReset = timeSource.markNow() + } + } + } +} + +/** Aggregated latency trackers for each category of SpacetimeDB operation. */ +public class Stats { + /** Tracks round-trip latency for reducer calls. */ + public val reducerRequestTracker: NetworkRequestTracker = NetworkRequestTracker() + + /** Tracks round-trip latency for procedure calls. */ + public val procedureRequestTracker: NetworkRequestTracker = NetworkRequestTracker() + + /** Tracks round-trip latency for subscription requests. */ + public val subscriptionRequestTracker: NetworkRequestTracker = NetworkRequestTracker() + + /** Tracks round-trip latency for one-off query requests. */ + public val oneOffRequestTracker: NetworkRequestTracker = NetworkRequestTracker() + + /** Tracks time spent applying incoming server messages to the client cache. 
*/ + public val applyMessageTracker: NetworkRequestTracker = NetworkRequestTracker() +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionBuilder.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionBuilder.kt new file mode 100644 index 00000000000..18316331323 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionBuilder.kt @@ -0,0 +1,52 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +/** + * Builder for configuring subscription callbacks before subscribing. + */ +public class SubscriptionBuilder internal constructor( + private val connection: DbConnection, +) { + private val onAppliedCallbacks = mutableListOf<(EventContext.SubscribeApplied) -> Unit>() + private val onErrorCallbacks = mutableListOf<(EventContext.Error, SubscriptionError) -> Unit>() + private val querySqls = mutableListOf<String>() + + /** Registers a callback invoked when the subscription's initial rows are applied. */ + public fun onApplied(cb: (EventContext.SubscribeApplied) -> Unit): SubscriptionBuilder = apply { + onAppliedCallbacks.add(cb) + } + + /** Registers a callback invoked when the subscription encounters an error. */ + public fun onError(cb: (EventContext.Error, SubscriptionError) -> Unit): SubscriptionBuilder = apply { + onErrorCallbacks.add(cb) + } + + /** + * Add a raw SQL query to the subscription. + */ + public fun addQuery(sql: String): SubscriptionBuilder = apply { + querySqls.add(sql) + } + + /** + * Subscribe with the accumulated queries. + * Requires at least one query added via [addQuery]. + */ + public fun subscribe(): SubscriptionHandle { + check(querySqls.isNotEmpty()) { "No queries added. Use addQuery() before subscribe()." 
} + return connection.subscribe(querySqls, onApplied = onAppliedCallbacks, onError = onErrorCallbacks) + } + + /** + * Subscribe to a single raw SQL query. + */ + public fun subscribe(query: String): SubscriptionHandle = + subscribe(listOf(query)) + + /** + * Subscribe to the given raw SQL queries. + */ + public fun subscribe(queries: List): SubscriptionHandle { + return connection.subscribe(queries, onApplied = onAppliedCallbacks, onError = onErrorCallbacks) + } + +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionHandle.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionHandle.kt new file mode 100644 index 00000000000..36362e348b9 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionHandle.kt @@ -0,0 +1,95 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QuerySetId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.UnsubscribeFlags +import kotlinx.atomicfu.atomic + +/** + * Subscription lifecycle state. + */ +public enum class SubscriptionState { + PENDING, + ACTIVE, + UNSUBSCRIBING, + ENDED, +} + +/** + * Handle to a subscription. + * + * Tracks the lifecycle: Pending -> Active -> Ended. + * - Active after SubscribeApplied received + * - Ended after UnsubscribeApplied or SubscriptionError received + */ +public class SubscriptionHandle internal constructor( + /** The server-assigned query set identifier for this subscription. */ + @InternalSpacetimeApi + public val querySetId: QuerySetId, + /** The SQL queries this subscription is tracking. 
*/ + public val queries: List, + private val connection: DbConnection, + private val onAppliedCallbacks: List<(EventContext.SubscribeApplied) -> Unit> = emptyList(), + private val onErrorCallbacks: List<(EventContext.Error, SubscriptionError) -> Unit> = emptyList(), +) { + private val _state = atomic(SubscriptionState.PENDING) + /** The current lifecycle state of this subscription. */ + public val state: SubscriptionState get() = _state.value + /** Whether the subscription is pending (sent but not yet confirmed by the server). */ + public val isPending: Boolean get() = _state.value == SubscriptionState.PENDING + /** Whether the subscription is active (confirmed and receiving updates). */ + public val isActive: Boolean get() = _state.value == SubscriptionState.ACTIVE + /** Whether an unsubscribe request has been sent but not yet confirmed. */ + public val isUnsubscribing: Boolean get() = _state.value == SubscriptionState.UNSUBSCRIBING + /** Whether the subscription has ended (unsubscribed or errored). */ + public val isEnded: Boolean get() = _state.value == SubscriptionState.ENDED + + private val _onEndCallback = atomic<((EventContext.UnsubscribeApplied) -> Unit)?>(null) + + /** + * Unsubscribe from this subscription. + * The onEnd callback will fire when the server confirms. + */ + public fun unsubscribe() { + doUnsubscribe() + } + + /** + * Unsubscribe and register a callback for when it completes. + */ + public fun unsubscribeThen( + onEnd: (EventContext.UnsubscribeApplied) -> Unit, + ) { + doUnsubscribe(onEnd) + } + + private fun doUnsubscribe( + onEnd: ((EventContext.UnsubscribeApplied) -> Unit)? = null, + ) { + if (!_state.compareAndSet(SubscriptionState.ACTIVE, SubscriptionState.UNSUBSCRIBING)) { + error("Cannot unsubscribe: subscription is ${_state.value}") + } + // Set callback AFTER the CAS succeeds. This is safe because handleEnd() + // only fires after the server receives our Unsubscribe message (sent below). 
+ if (onEnd != null) _onEndCallback.value = onEnd + connection.unsubscribe(this, UnsubscribeFlags.SendDroppedRows) + } + + internal suspend fun handleApplied(ctx: EventContext.SubscribeApplied) { + _state.value = SubscriptionState.ACTIVE + for (cb in onAppliedCallbacks) connection.runUserCallback { cb(ctx) } + } + + internal suspend fun handleError(ctx: EventContext.Error, error: SubscriptionError) { + _state.value = SubscriptionState.ENDED + for (cb in onErrorCallbacks) connection.runUserCallback { cb(ctx, error) } + } + + internal suspend fun handleEnd(ctx: EventContext.UnsubscribeApplied) { + _state.value = SubscriptionState.ENDED + _onEndCallback.value?.let { connection.runUserCallback { it.invoke(ctx) } } + } + + internal fun markEnded() { + _state.value = SubscriptionState.ENDED + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableQuery.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableQuery.kt new file mode 100644 index 00000000000..da8642bc367 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableQuery.kt @@ -0,0 +1,248 @@ +@file:OptIn(kotlin.experimental.ExperimentalTypeInference::class) + +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.jvm.JvmName + +/** + * A query that can be converted to a SQL string. + * Implemented by [Table], [FromWhere], [LeftSemiJoin], and [RightSemiJoin]. + */ +public interface Query<@Suppress("unused") TRow> { + /** Converts this query to its SQL string representation. */ + public fun toSql(): String +} + +/** + * A type-safe query reference for a specific table. + * Generated code creates these via per-table methods on `QueryBuilder`. 
+ * + * @param TRow the row type of this table + * @param TCols the column accessor class (generated per-table) + * @param TIxCols the indexed column accessor class (generated per-table) + */ +public class Table( + private val tableName: String, + internal val cols: TCols, + internal val ixCols: TIxCols, +) : Query { + internal val tableRefSql: String get() = SqlFormat.quoteIdent(tableName) + + override fun toSql(): String = "SELECT * FROM ${SqlFormat.quoteIdent(tableName)}" + + /** Adds a WHERE clause to this table query. */ + public fun where(predicate: (TCols) -> BoolExpr): FromWhere = + FromWhere(this, predicate(cols)) + + public fun where(predicate: (TCols, TIxCols) -> BoolExpr): FromWhere = + FromWhere(this, predicate(cols, ixCols)) + + @OverloadResolutionByLambdaReturnType + @JvmName("whereCol") + public fun where(predicate: (TCols) -> Col): FromWhere = + FromWhere(this, predicate(cols).eq(SqlLit.bool(true))) + + @OverloadResolutionByLambdaReturnType + @JvmName("whereColIx") + public fun where(predicate: (TCols, TIxCols) -> Col): FromWhere = + FromWhere(this, predicate(cols, ixCols).eq(SqlLit.bool(true))) + + @OverloadResolutionByLambdaReturnType + @JvmName("whereIxColIx") + public fun where(predicate: (TCols, TIxCols) -> IxCol): FromWhere = + FromWhere(this, predicate(cols, ixCols).eq(SqlLit.bool(true))) + + /** Alias for [where]; adds a WHERE clause to this table query. 
*/ + public fun filter(predicate: (TCols) -> BoolExpr): FromWhere = + FromWhere(this, predicate(cols)) + + public fun filter(predicate: (TCols, TIxCols) -> BoolExpr): FromWhere = + FromWhere(this, predicate(cols, ixCols)) + + @OverloadResolutionByLambdaReturnType + @JvmName("filterCol") + public fun filter(predicate: (TCols) -> Col): FromWhere = + FromWhere(this, predicate(cols).eq(SqlLit.bool(true))) + + @OverloadResolutionByLambdaReturnType + @JvmName("filterColIx") + public fun filter(predicate: (TCols, TIxCols) -> Col): FromWhere = + FromWhere(this, predicate(cols, ixCols).eq(SqlLit.bool(true))) + + @OverloadResolutionByLambdaReturnType + @JvmName("filterIxColIx") + public fun filter(predicate: (TCols, TIxCols) -> IxCol): FromWhere = + FromWhere(this, predicate(cols, ixCols).eq(SqlLit.bool(true))) + + /** Creates a left semi-join with [right], returning rows from this table where a match exists. */ + public fun leftSemijoin( + right: Table, + on: (TIxCols, TRIxCols) -> IxJoinEq, + ): LeftSemiJoin = + LeftSemiJoin(this, right, on(ixCols, right.ixCols)) + + /** Creates a right semi-join with [right], returning rows from the right table where a match exists. */ + public fun rightSemijoin( + right: Table, + on: (TIxCols, TRIxCols) -> IxJoinEq, + ): RightSemiJoin = + RightSemiJoin(this, right, on(ixCols, right.ixCols)) +} + +/** + * A table query with a WHERE clause. + * Created by calling [Table.where] or [Table.filter]. + * Additional [where] calls chain predicates with AND. + */ +public class FromWhere( + private val table: Table, + private val expr: BoolExpr, +) : Query { + override fun toSql(): String = "${table.toSql()} WHERE ${expr.sql}" + + /** Chains an additional AND predicate onto this query's WHERE clause. 
*/ + public fun where(predicate: (TCols) -> BoolExpr): FromWhere = + FromWhere(table, expr.and(predicate(table.cols))) + + public fun where(predicate: (TCols, TIxCols) -> BoolExpr): FromWhere = + FromWhere(table, expr.and(predicate(table.cols, table.ixCols))) + + @OverloadResolutionByLambdaReturnType + @JvmName("whereCol") + public fun where(predicate: (TCols) -> Col): FromWhere = + FromWhere(table, expr.and(predicate(table.cols).eq(SqlLit.bool(true)))) + + @OverloadResolutionByLambdaReturnType + @JvmName("whereColIx") + public fun where(predicate: (TCols, TIxCols) -> Col): FromWhere = + FromWhere(table, expr.and(predicate(table.cols, table.ixCols).eq(SqlLit.bool(true)))) + + @OverloadResolutionByLambdaReturnType + @JvmName("whereIxColIx") + public fun where(predicate: (TCols, TIxCols) -> IxCol): FromWhere = + FromWhere(table, expr.and(predicate(table.cols, table.ixCols).eq(SqlLit.bool(true)))) + + /** Alias for [where]; chains an additional AND predicate onto this query's WHERE clause. 
*/ + public fun filter(predicate: (TCols) -> BoolExpr): FromWhere = + FromWhere(table, expr.and(predicate(table.cols))) + + public fun filter(predicate: (TCols, TIxCols) -> BoolExpr): FromWhere = + FromWhere(table, expr.and(predicate(table.cols, table.ixCols))) + + @OverloadResolutionByLambdaReturnType + @JvmName("filterCol") + public fun filter(predicate: (TCols) -> Col): FromWhere = + FromWhere(table, expr.and(predicate(table.cols).eq(SqlLit.bool(true)))) + + @OverloadResolutionByLambdaReturnType + @JvmName("filterColIx") + public fun filter(predicate: (TCols, TIxCols) -> Col): FromWhere = + FromWhere(table, expr.and(predicate(table.cols, table.ixCols).eq(SqlLit.bool(true)))) + + @OverloadResolutionByLambdaReturnType + @JvmName("filterIxColIx") + public fun filter(predicate: (TCols, TIxCols) -> IxCol): FromWhere = + FromWhere(table, expr.and(predicate(table.cols, table.ixCols).eq(SqlLit.bool(true)))) + + /** Creates a left semi-join with [right], preserving this query's WHERE clause. */ + public fun leftSemijoin( + right: Table, + on: (TIxCols, TRIxCols) -> IxJoinEq, + ): LeftSemiJoin = + LeftSemiJoin(this.table, right, on(table.ixCols, right.ixCols), expr) + + /** Creates a right semi-join with [right], preserving this query's WHERE clause. */ + public fun rightSemijoin( + right: Table, + on: (TIxCols, TRIxCols) -> IxJoinEq, + ): RightSemiJoin = + RightSemiJoin(this.table, right, on(table.ixCols, right.ixCols), expr) +} + +/** + * A left semi-join query. Returns rows from the left table. + * Created by calling [Table.leftSemijoin] or [FromWhere.leftSemijoin]. + */ +public class LeftSemiJoin( + private val left: Table, + private val right: Table, + private val join: IxJoinEq, + private val whereExpr: BoolExpr? 
= null, +) : Query { + override fun toSql(): String { + val base = "SELECT ${left.tableRefSql}.* FROM ${left.tableRefSql} JOIN ${right.tableRefSql} ON ${join.leftRefSql} = ${join.rightRefSql}" + return if (whereExpr != null) "$base WHERE ${whereExpr.sql}" else base + } + + /** Adds a WHERE predicate on the left table's columns. */ + public fun where(predicate: (TLCols) -> BoolExpr): LeftSemiJoin { + val newExpr = predicate(left.cols) + return LeftSemiJoin(left, right, join, whereExpr?.and(newExpr) ?: newExpr) + } + + @OverloadResolutionByLambdaReturnType + @JvmName("whereCol") + public fun where(predicate: (TLCols) -> Col): LeftSemiJoin { + val newExpr = predicate(left.cols).eq(SqlLit.bool(true)) + return LeftSemiJoin(left, right, join, whereExpr?.and(newExpr) ?: newExpr) + } + + /** Alias for [where]; adds a WHERE predicate on the left table's columns. */ + public fun filter(predicate: (TLCols) -> BoolExpr): LeftSemiJoin { + val newExpr = predicate(left.cols) + return LeftSemiJoin(left, right, join, whereExpr?.and(newExpr) ?: newExpr) + } + + @OverloadResolutionByLambdaReturnType + @JvmName("filterCol") + public fun filter(predicate: (TLCols) -> Col): LeftSemiJoin { + val newExpr = predicate(left.cols).eq(SqlLit.bool(true)) + return LeftSemiJoin(left, right, join, whereExpr?.and(newExpr) ?: newExpr) + } +} + +/** + * A right semi-join query. Returns rows from the right table. + * Created by calling [Table.rightSemijoin] or [FromWhere.rightSemijoin]. + */ +public class RightSemiJoin( + private val left: Table, + private val right: Table, + private val join: IxJoinEq, + private val leftWhereExpr: BoolExpr? = null, + private val rightWhereExpr: BoolExpr? 
= null, +) : Query { + override fun toSql(): String { + val base = "SELECT ${right.tableRefSql}.* FROM ${left.tableRefSql} JOIN ${right.tableRefSql} ON ${join.leftRefSql} = ${join.rightRefSql}" + val conditions = mutableListOf() + if (leftWhereExpr != null) conditions.add(leftWhereExpr.sql) + if (rightWhereExpr != null) conditions.add(rightWhereExpr.sql) + return if (conditions.isEmpty()) base else "$base WHERE ${conditions.joinToString(" AND ")}" + } + + /** Adds a WHERE predicate on the right table's columns. */ + public fun where(predicate: (TRCols) -> BoolExpr): RightSemiJoin { + val newExpr = predicate(right.cols) + return RightSemiJoin(left, right, join, leftWhereExpr, rightWhereExpr?.and(newExpr) ?: newExpr) + } + + @OverloadResolutionByLambdaReturnType + @JvmName("whereCol") + public fun where(predicate: (TRCols) -> Col): RightSemiJoin { + val newExpr = predicate(right.cols).eq(SqlLit.bool(true)) + return RightSemiJoin(left, right, join, leftWhereExpr, rightWhereExpr?.and(newExpr) ?: newExpr) + } + + /** Alias for [where]; adds a WHERE predicate on the right table's columns. 
*/ + public fun filter(predicate: (TRCols) -> BoolExpr): RightSemiJoin { + val newExpr = predicate(right.cols) + return RightSemiJoin(left, right, join, leftWhereExpr, rightWhereExpr?.and(newExpr) ?: newExpr) + } + + @OverloadResolutionByLambdaReturnType + @JvmName("filterCol") + public fun filter(predicate: (TRCols) -> Col): RightSemiJoin { + val newExpr = predicate(right.cols).eq(SqlLit.bool(true)) + return RightSemiJoin(left, right, join, leftWhereExpr, rightWhereExpr?.and(newExpr) ?: newExpr) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UInt128.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UInt128.kt new file mode 100644 index 00000000000..6ff69f1d655 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UInt128.kt @@ -0,0 +1,21 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import kotlin.jvm.JvmInline + +/** An unsigned 128-bit integer, backed by [BigInteger]. */ +@JvmInline +public value class UInt128(public val value: BigInteger) : Comparable { + /** Encodes this value to BSATN. */ + public fun encode(writer: BsatnWriter): Unit = writer.writeU128(value) + override fun compareTo(other: UInt128): Int = value.compareTo(other.value) + override fun toString(): String = value.toString() + + public companion object { + /** Decodes a [UInt128] from BSATN. */ + public fun decode(reader: BsatnReader): UInt128 = UInt128(reader.readU128()) + /** A zero-valued [UInt128]. 
*/ + public val ZERO: UInt128 = UInt128(BigInteger.ZERO) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UInt256.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UInt256.kt new file mode 100644 index 00000000000..dcbd3b9099b --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UInt256.kt @@ -0,0 +1,21 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import kotlin.jvm.JvmInline + +/** An unsigned 256-bit integer, backed by [BigInteger]. */ +@JvmInline +public value class UInt256(public val value: BigInteger) : Comparable { + /** Encodes this value to BSATN. */ + public fun encode(writer: BsatnWriter): Unit = writer.writeU256(value) + override fun compareTo(other: UInt256): Int = value.compareTo(other.value) + override fun toString(): String = value.toString() + + public companion object { + /** Decodes a [UInt256] from BSATN. */ + public fun decode(reader: BsatnReader): UInt256 = UInt256(reader.readU256()) + /** A zero-valued [UInt256]. 
*/ + public val ZERO: UInt256 = UInt256(BigInteger.ZERO) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Util.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Util.kt new file mode 100644 index 00000000000..a2f92c501ea --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/Util.kt @@ -0,0 +1,33 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.random.Random +import kotlin.time.Instant + +internal fun BigInteger.toHexString(byteWidth: Int): String { + require(signum() >= 0) { "toHexString requires a non-negative value, got $this" } + return toString(16).padStart(byteWidth * 2, '0') +} + +internal fun parseHexString(hex: String): BigInteger = BigInteger.parseString(hex, 16) +internal fun randomBigInteger(byteLength: Int): BigInteger { + val bytes = ByteArray(byteLength) + Random.nextBytes(bytes) + return BigInteger.fromByteArray(bytes, Sign.POSITIVE) +} + + +internal fun Instant.Companion.fromEpochMicroseconds(micros: Long): Instant { + val seconds = micros.floorDiv(1_000_000L) + val nanos = micros.mod(1_000_000L).toInt() * 1_000 + return fromEpochSeconds(seconds, nanos) +} + +private const val MAX_EPOCH_SECONDS_FOR_MICROS = Long.MAX_VALUE / 1_000_000L +private const val MIN_EPOCH_SECONDS_FOR_MICROS = Long.MIN_VALUE / 1_000_000L + +internal fun Instant.toEpochMicroseconds(): Long { + require(epochSeconds in MIN_EPOCH_SECONDS_FOR_MICROS..MAX_EPOCH_SECONDS_FOR_MICROS) { + "Timestamp $this is outside the representable microsecond range" + } + return epochSeconds * 1_000_000L + (nanosecondsOfSecond / 1_000) +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/bsatn/BsatnReader.kt 
b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/bsatn/BsatnReader.kt new file mode 100644 index 00000000000..1e0a4b21a57 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/bsatn/BsatnReader.kt @@ -0,0 +1,195 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi + +/** + * Binary reader for BSATN decoding. All multi-byte values are little-endian. + */ +public class BsatnReader(internal var data: ByteArray, offset: Int = 0, private var limit: Int = data.size) { + + /** Current read position within the buffer. */ + @InternalSpacetimeApi + public var offset: Int = offset + private set + + /** Number of bytes remaining to be read. */ + @InternalSpacetimeApi + public val remaining: Int get() = limit - offset + + /** Resets this reader to decode from a new byte array from the beginning. */ + internal fun reset(newData: ByteArray) { + data = newData + offset = 0 + limit = newData.size + } + + /** Advances the read position by [n] bytes without returning data. */ + internal fun skip(n: Int) { + ensure(n) + offset += n + } + + private fun ensure(n: Int) { + check(n in 0..remaining) { "BsatnReader: need $n bytes but only $remaining remain" } + } + + /** Reads a BSATN boolean (1 byte, nonzero = true). */ + public fun readBool(): Boolean { + ensure(1) + val b = data[offset].toInt() and 0xFF + offset += 1 + return b != 0 + } + + /** Reads a single signed byte. */ + public fun readByte(): Byte { + ensure(1) + val b = data[offset] + offset += 1 + return b + } + + /** Reads a signed 8-bit integer. */ + public fun readI8(): Byte = readByte() + + /** Reads an unsigned 8-bit integer. 
*/ + public fun readU8(): UByte { + ensure(1) + val b = data[offset].toUByte() + offset += 1 + return b + } + + /** Reads a signed 16-bit integer (little-endian). */ + public fun readI16(): Short { + ensure(2) + val b0 = data[offset].toInt() and 0xFF + val b1 = data[offset + 1].toInt() and 0xFF + offset += 2 + return (b0 or (b1 shl 8)).toShort() + } + + /** Reads an unsigned 16-bit integer (little-endian). */ + public fun readU16(): UShort = readI16().toUShort() + + /** Reads a signed 32-bit integer (little-endian). */ + public fun readI32(): Int { + ensure(4) + val b0 = data[offset].toLong() and 0xFF + val b1 = data[offset + 1].toLong() and 0xFF + val b2 = data[offset + 2].toLong() and 0xFF + val b3 = data[offset + 3].toLong() and 0xFF + offset += 4 + return (b0 or (b1 shl 8) or (b2 shl 16) or (b3 shl 24)).toInt() + } + + /** Reads an unsigned 32-bit integer (little-endian). */ + public fun readU32(): UInt = readI32().toUInt() + + /** Reads a signed 64-bit integer (little-endian). */ + public fun readI64(): Long { + ensure(8) + var result = 0L + for (i in 0 until 8) { + result = result or ((data[offset + i].toLong() and 0xFF) shl (i * 8)) + } + offset += 8 + return result + } + + /** Reads an unsigned 64-bit integer (little-endian). */ + public fun readU64(): ULong = readI64().toULong() + + /** Reads a 32-bit IEEE 754 float (little-endian). */ + public fun readF32(): Float = Float.fromBits(readI32()) + + /** Reads a 64-bit IEEE 754 double (little-endian). */ + public fun readF64(): Double = Double.fromBits(readI64()) + + /** Reads a signed 128-bit integer (little-endian) as a [BigInteger]. */ + public fun readI128(): BigInteger { + ensure(16) + val result = BigInteger.fromLeBytes(data, offset, 16) + offset += 16 + return result + } + + /** Reads an unsigned 128-bit integer (little-endian) as a [BigInteger]. 
*/ + public fun readU128(): BigInteger { + ensure(16) + val result = BigInteger.fromLeBytesUnsigned(data, offset, 16) + offset += 16 + return result + } + + /** Reads a signed 256-bit integer (little-endian) as a [BigInteger]. */ + public fun readI256(): BigInteger { + ensure(32) + val result = BigInteger.fromLeBytes(data, offset, 32) + offset += 32 + return result + } + + /** Reads an unsigned 256-bit integer (little-endian) as a [BigInteger]. */ + public fun readU256(): BigInteger { + ensure(32) + val result = BigInteger.fromLeBytesUnsigned(data, offset, 32) + offset += 32 + return result + } + + /** Reads a BSATN length-prefixed UTF-8 string. */ + public fun readString(): String { + val len = readU32() + check(len <= Int.MAX_VALUE.toUInt()) { "String length $len exceeds maximum supported size" } + val bytes = readRawBytes(len.toInt()) + return bytes.decodeToString() + } + + /** Reads a BSATN length-prefixed byte array. */ + public fun readByteArray(): ByteArray { + val len = readU32() + check(len <= Int.MAX_VALUE.toUInt()) { "Byte array length $len exceeds maximum supported size" } + return readRawBytes(len.toInt()) + } + + private fun readRawBytes(length: Int): ByteArray { + ensure(length) + val result = data.copyOfRange(offset, offset + length) + offset += length + return result + } + + /** + * Returns a zero-copy view of the underlying buffer. + * The returned BsatnReader shares the same backing array — no allocation. + */ + internal fun readRawBytesView(length: Int): BsatnReader { + ensure(length) + val view = BsatnReader(data, offset, offset + length) + offset += length + return view + } + + /** + * Returns a copy of the underlying buffer between [from] and [to]. + * Used when a materialized ByteArray is needed (e.g. for content-based keying). 
+ */ + internal fun sliceArray(from: Int, to: Int): ByteArray { + check(to in from..limit) { + "sliceArray($from, $to) out of view bounds (limit=$limit)" + } + return data.copyOfRange(from, to) + } + + /** Reads a sum-type tag byte. */ + public fun readSumTag(): UByte = readU8() + + /** Reads a BSATN array length prefix (U32), returned as Int for indexing. */ + public fun readArrayLen(): Int { + val len = readU32() + check(len <= Int.MAX_VALUE.toUInt()) { "Array length $len exceeds maximum supported size" } + return len.toInt() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/bsatn/BsatnWriter.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/bsatn/BsatnWriter.kt new file mode 100644 index 00000000000..5661c722d2f --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/bsatn/BsatnWriter.kt @@ -0,0 +1,191 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi + +/** + * Resizable buffer for BSATN writing. Doubles capacity on overflow. + */ +internal class ResizableBuffer(initialCapacity: Int) { + var buffer: ByteArray = ByteArray(initialCapacity) + private set + + val capacity: Int get() = buffer.size + + fun grow(newSize: Int) { + if (newSize <= buffer.size) return + val newCapacity = maxOf(buffer.size * 2, newSize) + buffer = buffer.copyOf(newCapacity) + } +} + +/** + * Binary writer for BSATN encoding. + * Little-endian, length-prefixed strings/byte arrays, auto-growing buffer. 
+ */ +public class BsatnWriter(initialCapacity: Int = 256) { + private var buffer = ResizableBuffer(initialCapacity) + /** Number of bytes written so far. */ + @InternalSpacetimeApi + public var offset: Int = 0 + private set + + private fun expandBuffer(additionalCapacity: Int) { + val minCapacity = offset + additionalCapacity + if (minCapacity > buffer.capacity) buffer.grow(minCapacity) + } + + // ---------- Primitive Writes ---------- + + /** Writes a boolean as a single byte (1 = true, 0 = false). */ + public fun writeBool(value: Boolean) { + expandBuffer(1) + buffer.buffer[offset] = if (value) 1 else 0 + offset += 1 + } + + /** Writes a single signed byte. */ + public fun writeByte(value: Byte) { + expandBuffer(1) + buffer.buffer[offset] = value + offset += 1 + } + + /** Writes a single unsigned byte. */ + public fun writeUByte(value: UByte) { + writeByte(value.toByte()) + } + + /** Writes a signed 8-bit integer. */ + public fun writeI8(value: Byte): Unit = writeByte(value) + /** Writes an unsigned 8-bit integer. */ + public fun writeU8(value: UByte): Unit = writeUByte(value) + + /** Writes a signed 16-bit integer (little-endian). */ + public fun writeI16(value: Short) { + expandBuffer(2) + val v = value.toInt() + buffer.buffer[offset] = (v and 0xFF).toByte() + buffer.buffer[offset + 1] = ((v shr 8) and 0xFF).toByte() + offset += 2 + } + + /** Writes an unsigned 16-bit integer (little-endian). */ + public fun writeU16(value: UShort): Unit = writeI16(value.toShort()) + + /** Writes a signed 32-bit integer (little-endian). */ + public fun writeI32(value: Int) { + expandBuffer(4) + buffer.buffer[offset] = (value and 0xFF).toByte() + buffer.buffer[offset + 1] = ((value shr 8) and 0xFF).toByte() + buffer.buffer[offset + 2] = ((value shr 16) and 0xFF).toByte() + buffer.buffer[offset + 3] = ((value shr 24) and 0xFF).toByte() + offset += 4 + } + + /** Writes an unsigned 32-bit integer (little-endian). 
*/ + public fun writeU32(value: UInt): Unit = writeI32(value.toInt()) + + /** Writes a signed 64-bit integer (little-endian). */ + public fun writeI64(value: Long) { + expandBuffer(8) + for (i in 0 until 8) { + buffer.buffer[offset + i] = ((value shr (i * 8)) and 0xFF).toByte() + } + offset += 8 + } + + /** Writes an unsigned 64-bit integer (little-endian). */ + public fun writeU64(value: ULong): Unit = writeI64(value.toLong()) + + /** Writes a 32-bit IEEE 754 float (little-endian). */ + public fun writeF32(value: Float): Unit = writeI32(value.toRawBits()) + + /** Writes a 64-bit IEEE 754 double (little-endian). */ + public fun writeF64(value: Double): Unit = writeI64(value.toRawBits()) + + // ---------- Big Integer Writes ---------- + + /** Writes a signed 128-bit integer (little-endian). */ + public fun writeI128(value: BigInteger): Unit = writeSignedBigIntLE(value, 16) + + /** Writes an unsigned 128-bit integer (little-endian). */ + public fun writeU128(value: BigInteger): Unit = writeUnsignedBigIntLE(value, 16) + + /** Writes a signed 256-bit integer (little-endian). */ + public fun writeI256(value: BigInteger): Unit = writeSignedBigIntLE(value, 32) + + /** Writes an unsigned 256-bit integer (little-endian). 
*/ + public fun writeU256(value: BigInteger): Unit = writeUnsignedBigIntLE(value, 32) + + private fun writeSignedBigIntLE(value: BigInteger, byteSize: Int) { + require(value.fitsInSignedBytes(byteSize)) { + "Signed value does not fit in $byteSize bytes: $value" + } + expandBuffer(byteSize) + value.writeLeBytes(buffer.buffer, offset, byteSize) + offset += byteSize + } + + private fun writeUnsignedBigIntLE(value: BigInteger, byteSize: Int) { + require(value.signum() >= 0) { + "Unsigned value must be non-negative: $value" + } + require(value.fitsInUnsignedBytes(byteSize)) { + "Unsigned value does not fit in $byteSize bytes: $value" + } + expandBuffer(byteSize) + value.writeLeBytes(buffer.buffer, offset, byteSize) + offset += byteSize + } + + // ---------- Strings / Byte Arrays ---------- + + /** Length-prefixed string (U32 length + UTF-8 bytes) */ + public fun writeString(value: String) { + val bytes = value.encodeToByteArray() + writeU32(bytes.size.toUInt()) + writeRawBytes(bytes) + } + + /** Length-prefixed byte array (U32 length + raw bytes) */ + public fun writeByteArray(value: ByteArray) { + writeU32(value.size.toUInt()) + writeRawBytes(value) + } + + /** Raw bytes, no length prefix */ + internal fun writeRawBytes(bytes: ByteArray) { + expandBuffer(bytes.size) + bytes.copyInto(buffer.buffer, offset) + offset += bytes.size + } + + // ---------- Utilities ---------- + + /** Writes a sum-type tag byte. */ + public fun writeSumTag(tag: UByte): Unit = writeU8(tag) + + /** Writes a BSATN array length prefix (U32). */ + public fun writeArrayLen(length: Int) { + require(length >= 0) { "Array length must be non-negative, got $length" } + writeU32(length.toUInt()) + } + + /** Return the written buffer up to current offset */ + public fun toByteArray(): ByteArray = buffer.buffer.copyOf(offset) + + /** Returns the written bytes as a Base64-encoded string. 
*/ + @OptIn(ExperimentalEncodingApi::class) + @InternalSpacetimeApi + public fun toBase64(): String = Base64.encode(toByteArray()) + + /** Resets this writer, discarding all written data and re-allocating the buffer. */ + @InternalSpacetimeApi + public fun reset(initialCapacity: Int = 256) { + buffer = ResizableBuffer(initialCapacity) + offset = 0 + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/ClientMessage.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/ClientMessage.kt new file mode 100644 index 00000000000..6d1667ac0a4 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/ClientMessage.kt @@ -0,0 +1,149 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter + +/** Opaque identifier for a subscription query set. */ +@InternalSpacetimeApi +public data class QuerySetId(val id: UInt) { + /** Encodes this value to BSATN. */ + public fun encode(writer: BsatnWriter): Unit = writer.writeU32(id) +} + +/** Flags controlling server behavior when unsubscribing. */ +internal sealed interface UnsubscribeFlags { + /** Default unsubscribe behavior (rows are silently dropped). */ + data object Default : UnsubscribeFlags + /** Request that the server send the dropped rows back before completing. */ + data object SendDroppedRows : UnsubscribeFlags + + /** Encodes this value to BSATN. */ + fun encode(writer: BsatnWriter) { + when (this) { + is Default -> writer.writeSumTag(0u) + is SendDroppedRows -> writer.writeSumTag(1u) + } + } +} + +/** + * Messages sent from the client to the SpacetimeDB server. 
+ * Variant tags match the wire protocol (0=Subscribe, 1=Unsubscribe, 2=OneOffQuery, 3=CallReducer, 4=CallProcedure). + */ +internal sealed interface ClientMessage { + + /** Encodes this message to BSATN. */ + fun encode(writer: BsatnWriter) + + /** Request to subscribe to a set of SQL queries. */ + data class Subscribe( + val requestId: UInt, + val querySetId: QuerySetId, + val queryStrings: List, + ) : ClientMessage { + override fun encode(writer: BsatnWriter) { + writer.writeSumTag(0u) + writer.writeU32(requestId) + querySetId.encode(writer) + writer.writeArrayLen(queryStrings.size) + for (s in queryStrings) writer.writeString(s) + } + } + + /** Request to unsubscribe from a query set. */ + data class Unsubscribe( + val requestId: UInt, + val querySetId: QuerySetId, + val flags: UnsubscribeFlags, + ) : ClientMessage { + override fun encode(writer: BsatnWriter) { + writer.writeSumTag(1u) + writer.writeU32(requestId) + querySetId.encode(writer) + flags.encode(writer) + } + } + + /** A single-shot SQL query that does not create a subscription. */ + data class OneOffQuery( + val requestId: UInt, + val queryString: String, + ) : ClientMessage { + override fun encode(writer: BsatnWriter) { + writer.writeSumTag(2u) + writer.writeU32(requestId) + writer.writeString(queryString) + } + } + + /** Request to invoke a reducer on the server. 
*/ + data class CallReducer( + val requestId: UInt, + val flags: UByte, + val reducer: String, + val args: ByteArray, + ) : ClientMessage { + override fun encode(writer: BsatnWriter) { + writer.writeSumTag(3u) + writer.writeU32(requestId) + writer.writeU8(flags) + writer.writeString(reducer) + writer.writeByteArray(args) + } + + override fun equals(other: Any?): Boolean = + other is CallReducer && + requestId == other.requestId && + flags == other.flags && + reducer == other.reducer && + args.contentEquals(other.args) + + override fun hashCode(): Int { + var result = requestId.hashCode() + result = 31 * result + flags.hashCode() + result = 31 * result + reducer.hashCode() + result = 31 * result + args.contentHashCode() + return result + } + } + + /** Request to invoke a procedure on the server. */ + data class CallProcedure( + val requestId: UInt, + val flags: UByte, + val procedure: String, + val args: ByteArray, + ) : ClientMessage { + override fun encode(writer: BsatnWriter) { + writer.writeSumTag(4u) + writer.writeU32(requestId) + writer.writeU8(flags) + writer.writeString(procedure) + writer.writeByteArray(args) + } + + override fun equals(other: Any?): Boolean = + other is CallProcedure && + requestId == other.requestId && + flags == other.flags && + procedure == other.procedure && + args.contentEquals(other.args) + + override fun hashCode(): Int { + var result = requestId.hashCode() + result = 31 * result + flags.hashCode() + result = 31 * result + procedure.hashCode() + result = 31 * result + args.contentHashCode() + return result + } + } + + companion object { + /** Encodes a [ClientMessage] to a BSATN byte array. 
*/ + fun encodeToBytes(message: ClientMessage): ByteArray { + val writer = BsatnWriter() + message.encode(writer) + return writer.toByteArray() + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.kt b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.kt new file mode 100644 index 00000000000..2e86216e70f --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.kt @@ -0,0 +1,49 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode + +/** + * Compression tags matching the SpacetimeDB wire protocol. + * First byte of every WebSocket message indicates compression. + */ +internal object Compression { + /** No compression applied. */ + const val NONE: Byte = 0x00 + /** Brotli compression. */ + const val BROTLI: Byte = 0x01 + /** Gzip compression. */ + const val GZIP: Byte = 0x02 +} + +/** + * Result of decompressing a message: the payload bytes and the offset at which they start. + * For compressed messages, [data] is a freshly-allocated array and [offset] is 0. + * For uncompressed messages, [data] is the original array and [offset] skips the tag byte, + * avoiding an unnecessary allocation. + */ +internal class DecompressedPayload(val data: ByteArray, val offset: Int = 0) { + init { + require(offset in 0..data.size) { "offset $offset out of bounds for data of size ${data.size}" } + } + + /** Number of usable bytes in the payload (total data size minus the offset). */ + val size: Int get() = data.size - offset +} + +/** + * Strips the compression prefix byte and decompresses if needed. + * Returns the raw BSATN payload. 
/**
 * Strips the compression prefix byte and decompresses the payload if needed.
 * Returns the raw BSATN payload.
 */
internal expect fun decompressMessage(data: ByteArray): DecompressedPayload

/**
 * Default compression mode for this platform.
 * Native targets default to NONE (no decompression support); JVM/Android default to GZIP.
 */
internal expect val defaultCompressionMode: CompressionMode

/**
 * Compression modes supported on this platform.
 * The builder validates that the user-selected mode is in this set.
 */
// NOTE(review): element type was lost in transit ("Set"); restored to Set<CompressionMode>
// based on CompressionMode usage above — confirm against the actual platform declarations.
internal expect val availableCompressionModes: Set<CompressionMode>

// ---- file: shared_client/protocol/ServerMessage.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader

/** Hint describing how rows are packed in a [BsatnRowList]. */
@InternalSpacetimeApi
public sealed interface RowSizeHint {
    /** All rows have the same fixed byte size. */
    public data class FixedSize(val size: UShort) : RowSizeHint

    /** Variable-size rows; each offset marks where a row ends within the row data. */
    // NOTE(review): element type restored to ULong to match readU64() in decode below.
    public data class RowOffsets(val offsets: List<ULong>) : RowSizeHint

    public companion object {
        /** Decodes a [RowSizeHint] from BSATN (sum tag 0 = FixedSize, 1 = RowOffsets). */
        public fun decode(reader: BsatnReader): RowSizeHint =
            when (val tag = reader.readSumTag().toInt()) {
                0 -> FixedSize(reader.readU16())
                1 -> RowOffsets(List(reader.readArrayLen()) { reader.readU64() })
                else -> error("Unknown RowSizeHint tag: $tag")
            }
    }
}

/** A BSATN-encoded list of rows with an associated [RowSizeHint]. */
@InternalSpacetimeApi
public class BsatnRowList(
    public val sizeHint: RowSizeHint,
    private val rowsData: ByteArray,
    private val rowsOffset: Int = 0,
    private val rowsLimit: Int = rowsData.size,
) {
    /** Total byte size of the row data. */
    public val rowsSize: Int get() = rowsLimit - rowsOffset

    /** Creates a fresh [BsatnReader] over the row data. Safe to call multiple times. */
    public val rowsReader: BsatnReader get() = BsatnReader(rowsData, rowsOffset, rowsLimit)

    public companion object {
        /**
         * Decodes a [BsatnRowList] from BSATN.
         *
         * Zero-copy: the returned list views [offset, offset + len) of the reader's
         * backing array rather than copying the bytes; the reader is advanced past them.
         */
        public fun decode(reader: BsatnReader): BsatnRowList {
            val sizeHint = RowSizeHint.decode(reader)
            val rawLen = reader.readU32()
            check(rawLen <= Int.MAX_VALUE.toUInt()) { "BsatnRowList length $rawLen exceeds maximum supported size" }
            val len = rawLen.toInt()
            val data = reader.data
            val offset = reader.offset
            reader.skip(len)
            return BsatnRowList(sizeHint, data, offset, offset + len)
        }
    }
}

/** Rows belonging to a single table, identified by name. */
@InternalSpacetimeApi
public data class SingleTableRows(
    val table: String,
    val rows: BsatnRowList,
) {
    public companion object {
        /** Decodes a [SingleTableRows] from BSATN: table name followed by the row list. */
        public fun decode(reader: BsatnReader): SingleTableRows {
            val table = reader.readString()
            val rows = BsatnRowList.decode(reader)
            return SingleTableRows(table, rows)
        }
    }
}
/** Collection of rows grouped by table, returned from a query. */
@InternalSpacetimeApi
public data class QueryRows(
    // NOTE(review): element type restored to SingleTableRows to match decode() below.
    val tables: List<SingleTableRows>,
) {
    public companion object {
        /** Decodes a [QueryRows] from BSATN: an array of per-table row groups. */
        public fun decode(reader: BsatnReader): QueryRows {
            val len = reader.readArrayLen()
            return QueryRows(List(len) { SingleTableRows.decode(reader) })
        }
    }
}

/** Result of a query: either successful rows or an error message. */
@InternalSpacetimeApi
public sealed interface QueryResult {
    /** Successful query result containing the returned rows. */
    public data class Ok(val rows: QueryRows) : QueryResult

    /** Failed query result containing an error message. */
    public data class Err(val error: String) : QueryResult
}

/** Row updates for a single table within a transaction. */
@InternalSpacetimeApi
public sealed interface TableUpdateRows {
    /** Inserts and deletes for a persistent (stored) table. */
    public data class PersistentTable(
        val inserts: BsatnRowList,
        val deletes: BsatnRowList,
    ) : TableUpdateRows

    /** Events for an event (non-stored) table. */
    public data class EventTable(
        val events: BsatnRowList,
    ) : TableUpdateRows

    public companion object {
        /** Decodes a [TableUpdateRows] from BSATN (tag 0 = PersistentTable, 1 = EventTable). */
        public fun decode(reader: BsatnReader): TableUpdateRows =
            when (val tag = reader.readSumTag().toInt()) {
                0 -> PersistentTable(
                    inserts = BsatnRowList.decode(reader),
                    deletes = BsatnRowList.decode(reader),
                )
                1 -> EventTable(events = BsatnRowList.decode(reader))
                else -> error("Unknown TableUpdateRows tag: $tag")
            }
    }
}

/** Update for a single table: its name and the list of row changes. */
@InternalSpacetimeApi
public data class TableUpdate(
    val tableName: String,
    // NOTE(review): element type restored to TableUpdateRows to match decode() below.
    val rows: List<TableUpdateRows>,
) {
    public companion object {
        /** Decodes a [TableUpdate] from BSATN: table name followed by an array of row changes. */
        public fun decode(reader: BsatnReader): TableUpdate {
            val tableName = reader.readString()
            val len = reader.readArrayLen()
            return TableUpdate(tableName, List(len) { TableUpdateRows.decode(reader) })
        }
    }
}

/** Table updates scoped to a single query set. */
@InternalSpacetimeApi
public data class QuerySetUpdate(
    val querySetId: QuerySetId,
    // NOTE(review): element type restored to TableUpdate to match decode() below.
    val tables: List<TableUpdate>,
) {
    public companion object {
        /** Decodes a [QuerySetUpdate] from BSATN: the query-set id followed by its table updates. */
        public fun decode(reader: BsatnReader): QuerySetUpdate {
            val querySetId = QuerySetId(reader.readU32())
            val len = reader.readArrayLen()
            return QuerySetUpdate(querySetId, List(len) { TableUpdate.decode(reader) })
        }
    }
}

/** A complete transaction update containing changes across all affected query sets. */
@InternalSpacetimeApi
public data class TransactionUpdate(
    // NOTE(review): element type restored to QuerySetUpdate to match decode() below.
    val querySets: List<QuerySetUpdate>,
) {
    public companion object {
        /** Decodes a [TransactionUpdate] from BSATN: an array of per-query-set updates. */
        public fun decode(reader: BsatnReader): TransactionUpdate {
            val len = reader.readArrayLen()
            return TransactionUpdate(List(len) { QuerySetUpdate.decode(reader) })
        }
    }
}

/** Outcome of a reducer execution on the server. */
@InternalSpacetimeApi
public sealed interface ReducerOutcome {
    /** Reducer succeeded with a return value and transaction update. */
    public data class Ok(
        val retValue: ByteArray,
        val transactionUpdate: TransactionUpdate,
    ) : ReducerOutcome {
        // ByteArray uses reference equality by default; compare contents instead.
        override fun equals(other: Any?): Boolean =
            other is Ok &&
                retValue.contentEquals(other.retValue) &&
                transactionUpdate == other.transactionUpdate

        override fun hashCode(): Int {
            var result = retValue.contentHashCode()
            result = 31 * result + transactionUpdate.hashCode()
            return result
        }
    }

    /** Reducer succeeded with no return value and no table changes. */
    public data object OkEmpty : ReducerOutcome

    /** Reducer failed with a BSATN-encoded error. */
    public data class Err(val error: ByteArray) : ReducerOutcome {
        // ByteArray uses reference equality by default; compare contents instead.
        override fun equals(other: Any?): Boolean =
            other is Err && error.contentEquals(other.error)

        override fun hashCode(): Int = error.contentHashCode()
    }

    /** Reducer encountered an internal server error. */
    public data class InternalError(val message: String) : ReducerOutcome

    public companion object {
        /** Decodes a [ReducerOutcome] from BSATN (tags 0=Ok, 1=OkEmpty, 2=Err, 3=InternalError). */
        public fun decode(reader: BsatnReader): ReducerOutcome =
            when (val tag = reader.readSumTag().toInt()) {
                0 -> Ok(
                    retValue = reader.readByteArray(),
                    transactionUpdate = TransactionUpdate.decode(reader),
                )
                1 -> OkEmpty
                2 -> Err(reader.readByteArray())
                3 -> InternalError(reader.readString())
                else -> error("Unknown ReducerOutcome tag: $tag")
            }
    }
}

/** Status of a procedure execution on the server. */
@InternalSpacetimeApi
public sealed interface ProcedureStatus {
    /** Procedure returned successfully with a BSATN-encoded value. */
    public data class Returned(val value: ByteArray) : ProcedureStatus {
        // ByteArray uses reference equality by default; compare contents instead.
        override fun equals(other: Any?): Boolean =
            other is Returned && value.contentEquals(other.value)

        override fun hashCode(): Int = value.contentHashCode()
    }

    /** Procedure encountered an internal server error. */
    public data class InternalError(val message: String) : ProcedureStatus

    public companion object {
        /** Decodes a [ProcedureStatus] from BSATN (tag 0=Returned, 1=InternalError). */
        public fun decode(reader: BsatnReader): ProcedureStatus =
            when (val tag = reader.readSumTag().toInt()) {
                0 -> Returned(reader.readByteArray())
                1 -> InternalError(reader.readString())
                else -> error("Unknown ProcedureStatus tag: $tag")
            }
    }
}
/**
 * Messages received from the SpacetimeDB server.
 * Variant tags match the wire protocol (0=InitialConnection through 7=ProcedureResult).
 */
@InternalSpacetimeApi
public sealed interface ServerMessage {

    /** Server confirmed the connection and assigned identity/token. */
    public data class InitialConnection(
        val identity: Identity,
        val connectionId: ConnectionId,
        val token: String,
    ) : ServerMessage

    /** Server applied a subscription and returned the initial matching rows. */
    public data class SubscribeApplied(
        val requestId: UInt,
        val querySetId: QuerySetId,
        val rows: QueryRows,
    ) : ServerMessage

    /** Server confirmed an unsubscription, optionally returning dropped rows. */
    public data class UnsubscribeApplied(
        val requestId: UInt,
        val querySetId: QuerySetId,
        val rows: QueryRows?,
    ) : ServerMessage

    /** Server reported an error for a subscription. */
    public data class SubscriptionError(
        val requestId: UInt?,
        val querySetId: QuerySetId,
        val error: String,
    ) : ServerMessage

    /** A transaction update containing table changes from a server-side event. */
    public data class TransactionUpdateMsg(
        val update: TransactionUpdate,
    ) : ServerMessage

    /** Result of a one-off SQL query. */
    public data class OneOffQueryResult(
        val requestId: UInt,
        val result: QueryResult,
    ) : ServerMessage

    /** Result of a reducer call, including timestamp and outcome. */
    public data class ReducerResultMsg(
        val requestId: UInt,
        val timestamp: Timestamp,
        val result: ReducerOutcome,
    ) : ServerMessage

    /** Result of a procedure call, including status and execution duration. */
    public data class ProcedureResultMsg(
        val status: ProcedureStatus,
        val timestamp: Timestamp,
        val totalHostExecutionDuration: TimeDuration,
        val requestId: UInt,
    ) : ServerMessage

    public companion object {
        /**
         * Decodes a [ServerMessage] from BSATN by dispatching on the leading sum tag.
         * Option fields use tag 0 = Some / 1 = None; Result fields use tag 0 = Ok / 1 = Err.
         */
        public fun decode(reader: BsatnReader): ServerMessage {
            val messageTag = reader.readSumTag().toInt()
            return when (messageTag) {
                0 -> InitialConnection(
                    identity = Identity.decode(reader),
                    connectionId = ConnectionId.decode(reader),
                    token = reader.readString(),
                )

                1 -> SubscribeApplied(
                    requestId = reader.readU32(),
                    querySetId = QuerySetId(reader.readU32()),
                    rows = QueryRows.decode(reader),
                )

                2 -> {
                    val reqId = reader.readU32()
                    val qsId = QuerySetId(reader.readU32())
                    // Optional dropped rows: tag 0 = Some, tag 1 = None.
                    val dropped = when (reader.readSumTag().toInt()) {
                        0 -> QueryRows.decode(reader)
                        1 -> null
                        else -> error("Invalid Option tag")
                    }
                    UnsubscribeApplied(reqId, qsId, dropped)
                }

                3 -> {
                    // Optional request id: tag 0 = Some, tag 1 = None.
                    val reqId = when (reader.readSumTag().toInt()) {
                        0 -> reader.readU32()
                        1 -> null
                        else -> error("Invalid Option tag")
                    }
                    SubscriptionError(
                        requestId = reqId,
                        querySetId = QuerySetId(reader.readU32()),
                        error = reader.readString(),
                    )
                }

                4 -> TransactionUpdateMsg(TransactionUpdate.decode(reader))

                5 -> {
                    val reqId = reader.readU32()
                    // Result payload: tag 0 = Ok(rows), tag 1 = Err(message).
                    val outcome = when (reader.readSumTag().toInt()) {
                        0 -> QueryResult.Ok(QueryRows.decode(reader))
                        1 -> QueryResult.Err(reader.readString())
                        else -> error("Invalid Result tag")
                    }
                    OneOffQueryResult(reqId, outcome)
                }

                6 -> ReducerResultMsg(
                    requestId = reader.readU32(),
                    timestamp = Timestamp.decode(reader),
                    result = ReducerOutcome.decode(reader),
                )

                7 -> ProcedureResultMsg(
                    status = ProcedureStatus.decode(reader),
                    timestamp = Timestamp.decode(reader),
                    totalHostExecutionDuration = TimeDuration.decode(reader),
                    requestId = reader.readU32(),
                )

                else -> error("Unknown ServerMessage tag: $messageTag")
            }
        }

        /** Decodes a [ServerMessage] from a raw byte array, starting at [offset]. */
        public fun decodeFromBytes(data: ByteArray, offset: Int = 0): ServerMessage =
            decode(BsatnReader(data, offset = offset))
    }
}

// ---- file: shared_client/transport/SpacetimeTransport.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ClientMessage
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.decompressMessage
import io.ktor.client.HttpClient
import io.ktor.client.plugins.websocket.webSocketSession
import io.ktor.client.request.header
import io.ktor.http.URLBuilder
import io.ktor.http.URLProtocol
import io.ktor.http.Url
import io.ktor.http.appendPathSegments
import io.ktor.websocket.Frame
import io.ktor.websocket.WebSocketSession
import io.ktor.websocket.close
import io.ktor.websocket.readBytes
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
/**
 * Transport abstraction for SpacetimeDB connections.
 * Allows injecting a fake transport in tests.
 */
internal interface Transport {
    suspend fun connect()
    suspend fun send(message: ClientMessage)
    // NOTE(review): element type was lost in transit ("Flow"); restored to Flow<ServerMessage>
    // to match the implementation below, which emits decoded ServerMessages.
    fun incoming(): Flow<ServerMessage>
    suspend fun disconnect()
}

/**
 * WebSocket transport for SpacetimeDB.
 * Handles connection, message encoding/decoding, and compression.
 */
internal class SpacetimeTransport(
    private val client: HttpClient,
    private val baseUrl: String,
    private val nameOrAddress: String,
    private val connectionId: ConnectionId,
    private val authToken: String? = null,
    private val compression: CompressionMode = CompressionMode.GZIP,
    private val lightMode: Boolean = false,
    private val confirmedReads: Boolean? = null,
) : Transport {
    // Active WebSocket session, or null when disconnected.
    // NOTE(review): type parameter was lost in transit ("atomic(null)"); restored to
    // atomic<WebSocketSession?>(null) to match how _session.value is used below.
    private val _session = atomic<WebSocketSession?>(null)

    internal companion object {
        /** WebSocket sub-protocol identifier for BSATN v2. */
        const val WS_PROTOCOL: String = "v2.bsatn.spacetimedb"
    }

    /**
     * Connects to the SpacetimeDB WebSocket endpoint.
     * Passes the auth token as a Bearer Authorization header on the WebSocket connection.
     */
    override suspend fun connect() {
        val wsUrl = buildWsUrl()

        _session.value = client.webSocketSession(wsUrl) {
            header("Sec-WebSocket-Protocol", WS_PROTOCOL)
            if (authToken != null) {
                header("Authorization", "Bearer $authToken")
            }
        }
    }

    /**
     * Sends a [ClientMessage] over the WebSocket as a BSATN-encoded binary frame.
     */
    override suspend fun send(message: ClientMessage) {
        val writer = BsatnWriter()
        message.encode(writer)
        val encoded = writer.toByteArray()
        val ws = _session.value ?: error("Not connected")
        ws.send(Frame.Binary(true, encoded))
    }

    /**
     * Returns a Flow of ServerMessages received from the WebSocket.
     * Handles decompression (prefix byte) then BSATN decoding.
     */
    override fun incoming(): Flow<ServerMessage> = flow {
        val ws = _session.value ?: error("Not connected")
        // On clean close, the for-loop exits normally (hasNext() returns false).
        // On abnormal close, hasNext() throws the original cause (e.g. IOException),
        // which propagates to DbConnection's error handling path.
        for (frame in ws.incoming) {
            if (frame is Frame.Binary) {
                val raw = frame.readBytes()
                val decompressed = decompressMessage(raw)
                val message = ServerMessage.decodeFromBytes(decompressed.data, decompressed.offset)
                emit(message)
            }
        }
    }

    /** Closes the WebSocket session, if one is open. */
    override suspend fun disconnect() {
        val ws = _session.getAndSet(null)
        ws?.close()
    }

    /**
     * Builds the subscribe URL: maps http(s) schemes to ws(s), appends the
     * /v1/database/{name}/subscribe path, and attaches connection/compression
     * query parameters.
     */
    private fun buildWsUrl(): String {
        val base = Url(baseUrl)
        return URLBuilder(base).apply {
            protocol = when (base.protocol) {
                URLProtocol.HTTPS -> URLProtocol.WSS
                URLProtocol.HTTP -> URLProtocol.WS
                URLProtocol.WSS -> URLProtocol.WSS
                URLProtocol.WS -> URLProtocol.WS
                else -> throw IllegalArgumentException(
                    "Unsupported protocol '${base.protocol.name}'. Use http://, https://, ws://, or wss://"
                )
            }
            appendPathSegments("v1", "database", nameOrAddress, "subscribe")
            parameters.append("connection_id", connectionId.toHexString())
            parameters.append("compression", compression.wireValue)
            if (lightMode) {
                parameters.append("light", "true")
            }
            if (confirmedReads != null) {
                parameters.append("confirmed", confirmedReads.toString())
            }
        }.buildString()
    }
}
// ---- file: shared_client/type/ConnectionId.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.parseHexString
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.randomBigInteger
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.toHexString

/** A 128-bit connection identifier in SpacetimeDB. */
public data class ConnectionId(val data: BigInteger) {
    /** Encodes this value to BSATN. */
    public fun encode(writer: BsatnWriter): Unit = writer.writeU128(data)

    /** Returns this connection ID as a 32-character lowercase hex string. */
    public fun toHexString(): String = data.toHexString(16) // U128 = 16 bytes = 32 hex chars

    override fun toString(): String = toHexString()

    /** Whether this connection ID is all zeros. */
    public fun isZero(): Boolean = data == BigInteger.ZERO

    /**
     * Returns the 16-byte little-endian representation, matching BSATN wire format.
     */
    public fun toByteArray(): ByteArray = data.toLeBytesFixedWidth(16)

    public companion object {
        /** Decodes a [ConnectionId] from BSATN. */
        public fun decode(reader: BsatnReader): ConnectionId = ConnectionId(reader.readU128())

        /** Returns a zero [ConnectionId]. */
        public fun zero(): ConnectionId = ConnectionId(BigInteger.ZERO)

        /** Returns `null` if the given [ConnectionId] is zero, otherwise returns it unchanged. */
        public fun nullIfZero(addr: ConnectionId): ConnectionId? = if (addr.isZero()) null else addr

        /** Returns a random [ConnectionId]. */
        public fun random(): ConnectionId = ConnectionId(randomBigInteger(16)) /* 16 bytes = 128 bits */

        /** Parses a [ConnectionId] from a hex string. */
        public fun fromHexString(hex: String): ConnectionId = ConnectionId(parseHexString(hex))

        /** Parses a [ConnectionId] from a hex string, returning `null` if parsing fails or the result is zero. */
        public fun fromHexStringOrNull(hex: String): ConnectionId? {
            val id = try { fromHexString(hex) } catch (_: Exception) { return null }
            return nullIfZero(id)
        }
    }
}

// ---- file: shared_client/type/Identity.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.parseHexString
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.toHexString

/** A 256-bit identity that uniquely identifies a user in SpacetimeDB. */
// NOTE(review): type argument was lost in transit ("Comparable"); restored to
// Comparable<Identity> to match compareTo(other: Identity) below.
public data class Identity(val data: BigInteger) : Comparable<Identity> {
    override fun compareTo(other: Identity): Int = data.compareTo(other.data)

    /** Encodes this value to BSATN. */
    public fun encode(writer: BsatnWriter): Unit = writer.writeU256(data)

    /** Returns this identity as a 64-character lowercase hex string. */
    public fun toHexString(): String = data.toHexString(32) // U256 = 32 bytes = 64 hex chars

    /**
     * Returns the 32-byte little-endian representation, matching BSATN wire format.
     */
    public fun toByteArray(): ByteArray = data.toLeBytesFixedWidth(32)

    override fun toString(): String = toHexString()

    public companion object {
        /** Decodes an [Identity] from BSATN. */
        public fun decode(reader: BsatnReader): Identity = Identity(reader.readU256())

        /** Parses an [Identity] from a hex string. */
        public fun fromHexString(hex: String): Identity = Identity(parseHexString(hex))

        /** Returns a zero [Identity]. */
        public fun zero(): Identity = Identity(BigInteger.ZERO)
    }
}
// ---- file: shared_client/type/ScheduleAt.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import kotlin.time.Duration
import kotlin.time.Instant

/** Specifies when a scheduled reducer should fire: at a fixed time or after an interval. */
public sealed interface ScheduleAt {
    /** Schedule by repeating interval. */
    public data class Interval(val duration: TimeDuration) : ScheduleAt

    /** Schedule at a specific point in time. */
    public data class Time(val timestamp: Timestamp) : ScheduleAt

    /** Encodes this value to BSATN: sum tag followed by the variant payload. */
    public fun encode(writer: BsatnWriter) {
        when (this) {
            is Interval -> {
                writer.writeSumTag(INTERVAL_TAG)
                duration.encode(writer)
            }

            is Time -> {
                writer.writeSumTag(TIME_TAG)
                timestamp.encode(writer)
            }
        }
    }

    public companion object {
        private const val INTERVAL_TAG: UByte = 0u
        private const val TIME_TAG: UByte = 1u

        /** Creates a [ScheduleAt] from a repeating [interval]. */
        public fun interval(interval: Duration): ScheduleAt = Interval(TimeDuration(interval))

        /** Creates a [ScheduleAt] for a specific point in [time]. */
        public fun time(time: Instant): ScheduleAt = Time(Timestamp(time))

        /** Decodes a [ScheduleAt] from BSATN (tag 0 = Interval, 1 = Time). */
        public fun decode(reader: BsatnReader): ScheduleAt =
            when (val tag = reader.readSumTag().toInt()) {
                0 -> Interval(TimeDuration.decode(reader))
                1 -> Time(Timestamp.decode(reader))
                else -> error("Unknown ScheduleAt tag: $tag")
            }
    }
}

// ---- file: shared_client/type/SpacetimeUuid.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.BigInteger
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Sign
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.toEpochMicroseconds
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlin.uuid.Uuid

/** Thread-safe monotonic counter for UUID V7 generation. */
public class Counter(value: Int = 0) {
    private val _value = atomic(value)

    // Atomically returns the current value and advances it, wrapping within 31 bits.
    internal fun getAndIncrement(): Int =
        _value.getAndUpdate { (it + 1) and 0x7FFF_FFFF }
}

/** UUID version detected from the version nibble. */
public enum class UuidVersion { Nil, V4, V7, Max, Unknown }
/** A UUID wrapper providing BSATN encoding and V4/V7 generation for SpacetimeDB. */
// NOTE(review): type argument was lost in transit ("Comparable"); restored to
// Comparable<SpacetimeUuid> to match compareTo(other: SpacetimeUuid) below.
public data class SpacetimeUuid(val data: Uuid) : Comparable<SpacetimeUuid> {
    // Lexicographic comparison of the big-endian byte representations,
    // treating each byte as unsigned.
    override fun compareTo(other: SpacetimeUuid): Int {
        val a = data.toByteArray()
        val b = other.data.toByteArray()
        for (i in a.indices) {
            val cmp = (a[i].toInt() and 0xFF).compareTo(b[i].toInt() and 0xFF)
            if (cmp != 0) return cmp
        }
        return 0
    }

    /** Encodes this value to BSATN as an unsigned 128-bit integer. */
    public fun encode(writer: BsatnWriter) {
        val value = BigInteger.fromByteArray(data.toByteArray(), Sign.POSITIVE)
        writer.writeU128(value)
    }

    override fun toString(): String = data.toString()

    /** Returns this UUID as a 32-character lowercase hex string. */
    public fun toHexString(): String = data.toHexString()

    /** Returns the 16-byte big-endian representation of this UUID. */
    public fun toByteArray(): ByteArray = data.toByteArray()

    /**
     * Extracts the 31-bit monotonic counter from a V7 UUID.
     *
     * UUID V7 byte layout (as produced by [fromCounterV7]):
     * ```
     * Byte:  0 1 2 3 4 5 | 6    | 7    | 8    | 9 10 11       | 12 13 14 15
     *        [--- timestamp ---][ver ][ctr ][var ][-- counter --] [-- random --]
     * ```
     * - Bytes 0–5: 48-bit Unix timestamp in milliseconds
     * - Byte 6: UUID version nibble (0x70 for V7) — **not** counter data, skipped
     * - Byte 7: counter bits 30–23
     * - Byte 8: RFC 4122 variant bits (0x80) — **not** counter data, skipped
     * - Bytes 9–11: counter bits 22–0 (byte 11 carries bits 6–0 in its upper 7 bits,
     *   i.e. shifted left by one — hence the `shr 1` below)
     */
    public fun getCounter(): Int {
        val b = data.toByteArray()
        return ((b[7].toInt() and 0xFF) shl 23) or
            ((b[9].toInt() and 0xFF) shl 15) or
            ((b[10].toInt() and 0xFF) shl 7) or
            ((b[11].toInt() and 0xFF) shr 1)
    }

    /** Detects the UUID version from the version nibble in byte 6. */
    public fun getVersion(): UuidVersion {
        if (data == Uuid.NIL) return UuidVersion.Nil
        val bytes = data.toByteArray()
        if (bytes.all { it == 0xFF.toByte() }) return UuidVersion.Max
        return when ((bytes[6].toInt() shr 4) and 0x0F) {
            4 -> UuidVersion.V4
            7 -> UuidVersion.V7
            else -> UuidVersion.Unknown
        }
    }

    public companion object {
        /** The nil UUID (all zeros). */
        public val NIL: SpacetimeUuid = SpacetimeUuid(Uuid.NIL)

        /** The max UUID (all ones). */
        public val MAX: SpacetimeUuid = SpacetimeUuid(Uuid.fromByteArray(ByteArray(16) { 0xFF.toByte() }))

        /** Decodes from BSATN, left-padding or truncating to exactly 16 bytes. */
        public fun decode(reader: BsatnReader): SpacetimeUuid {
            val value = reader.readU128()
            val bytes = value.toByteArray()
            val padded = if (bytes.size >= 16) bytes.copyOfRange(bytes.size - 16, bytes.size)
            else ByteArray(16 - bytes.size) + bytes
            return SpacetimeUuid(Uuid.fromByteArray(padded))
        }

        /** Generates a random V4 UUID using the platform's secure random. */
        public fun random(): SpacetimeUuid = SpacetimeUuid(Uuid.random())

        /** Creates a V4 UUID from 16 random bytes, setting the version and variant bits. */
        public fun fromRandomBytesV4(bytes: ByteArray): SpacetimeUuid {
            require(bytes.size == 16) { "UUID v4 requires exactly 16 bytes, got ${bytes.size}" }
            val b = bytes.copyOf()
            b[6] = ((b[6].toInt() and 0x0F) or 0x40).toByte() // version 4
            b[8] = ((b[8].toInt() and 0x3F) or 0x80).toByte() // variant RFC 4122
            return SpacetimeUuid(Uuid.fromByteArray(b))
        }

        /**
         * Creates a V7 UUID with the given counter, timestamp, and random bytes.
         *
         * UUID V7 byte layout:
         * ```
         * Byte:  0 1 2 3 4 5 | 6    | 7    | 8    | 9 10 11       | 12 13 14 15
         *        [--- timestamp ---][ver ][ctr ][var ][-- counter --] [-- random --]
         * ```
         * - Bytes 0–5: 48-bit Unix timestamp in milliseconds (big-endian)
         * - Byte 6: UUID version nibble, fixed to `0x70` (V7)
         * - Byte 7: counter bits 30–23
         * - Byte 8: RFC 4122 variant, fixed to `0x80`
         * - Bytes 9–11: counter bits 22–0 (byte 11 stores bits 6–0 shifted left by one)
         * - Bytes 12–15: random bytes (the high bit of byte 12 is masked to zero)
         *
         * Bytes 6 and 8 hold fixed version/variant metadata and are **not** part of
         * the counter, which is why [getCounter] skips them when reading back.
         */
        public fun fromCounterV7(counter: Counter, now: Timestamp, randomBytes: ByteArray): SpacetimeUuid {
            require(randomBytes.size >= 4) { "V7 UUID requires at least 4 random bytes, got ${randomBytes.size}" }
            val counterVal = counter.getAndIncrement()

            val tsMs = now.instant.toEpochMicroseconds() / 1_000

            val b = ByteArray(16)
            // Bytes 0-5: 48-bit unix timestamp (ms), big-endian
            b[0] = (tsMs shr 40).toByte()
            b[1] = (tsMs shr 32).toByte()
            b[2] = (tsMs shr 24).toByte()
            b[3] = (tsMs shr 16).toByte()
            b[4] = (tsMs shr 8).toByte()
            b[5] = tsMs.toByte()
            // Byte 6: version 7 (fixed — not counter data)
            b[6] = 0x70.toByte()
            // Byte 7: counter bits 30-23
            b[7] = ((counterVal shr 23) and 0xFF).toByte()
            // Byte 8: variant RFC 4122 (fixed — not counter data)
            b[8] = 0x80.toByte()
            // Bytes 9-11: counter bits 22-0
            b[9] = ((counterVal shr 15) and 0xFF).toByte()
            b[10] = ((counterVal shr 7) and 0xFF).toByte()
            b[11] = ((counterVal and 0x7F) shl 1).toByte()
            // Bytes 12-15: random bytes
            b[12] = (randomBytes[0].toInt() and 0x7F).toByte()
            b[13] = randomBytes[1]
            b[14] = randomBytes[2]
            b[15] = randomBytes[3]

            return SpacetimeUuid(Uuid.fromByteArray(b))
        }

        /** Parses a UUID from its standard string representation (e.g. `550e8400-e29b-41d4-a716-446655440000`). */
        public fun parse(str: String): SpacetimeUuid = SpacetimeUuid(Uuid.parse(str))
    }
}

// ---- file: shared_client/type/TimeDuration.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import kotlin.math.abs
import kotlin.time.Duration
import kotlin.time.Duration.Companion.microseconds
import kotlin.time.Duration.Companion.milliseconds

/** A duration with microsecond precision, backed by [Duration]. */
// NOTE(review): type argument was lost in transit ("Comparable"); restored to
// Comparable<TimeDuration> to match compareTo(other: TimeDuration) below.
public data class TimeDuration(val duration: Duration) : Comparable<TimeDuration> {
    /** Encodes this value to BSATN as signed microseconds. */
    public fun encode(writer: BsatnWriter): Unit = writer.writeI64(duration.inWholeMicroseconds)

    /** This duration in whole microseconds. */
    public val micros: Long get() = duration.inWholeMicroseconds

    /** This duration in whole milliseconds. */
    public val millis: Long get() = duration.inWholeMilliseconds

    /** Returns the sum of this duration and [other]. */
    public operator fun plus(other: TimeDuration): TimeDuration =
        TimeDuration(duration + other.duration)

    /** Returns the difference between this duration and [other]. */
    public operator fun minus(other: TimeDuration): TimeDuration =
        TimeDuration(duration - other.duration)

    override operator fun compareTo(other: TimeDuration): Int =
        duration.compareTo(other.duration)

    /** Formats as explicitly-signed seconds with a 6-digit microsecond fraction. */
    override fun toString(): String {
        val sign = if (duration.inWholeMicroseconds >= 0) "+" else "-"
        val abs = abs(duration.inWholeMicroseconds)
        val secs = abs / 1_000_000
        val frac = abs % 1_000_000
        return "$sign$secs.${frac.toString().padStart(6, '0')}"
    }

    public companion object {
        /** Decodes a [TimeDuration] from BSATN (signed microseconds). */
        public fun decode(reader: BsatnReader): TimeDuration = TimeDuration(reader.readI64().microseconds)

        /** Creates a [TimeDuration] from milliseconds. */
        public fun fromMillis(millis: Long): TimeDuration = TimeDuration(millis.milliseconds)
    }
}
// ---- file: shared_client/type/Timestamp.kt ----

package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type

import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.fromEpochMicroseconds
import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.toEpochMicroseconds
import kotlin.time.Clock
import kotlin.time.Duration.Companion.microseconds
import kotlin.time.Instant

/** A microsecond-precision timestamp backed by [Instant]. */
// NOTE(review): type argument was lost in transit ("Comparable"); restored to
// Comparable<Timestamp> to match compareTo(other: Timestamp) below.
public data class Timestamp(val instant: Instant) : Comparable<Timestamp> {
    public companion object {
        /** The Unix epoch (1970-01-01T00:00:00Z). */
        public val UNIX_EPOCH: Timestamp = Timestamp(Instant.fromEpochMilliseconds(0))

        /** Returns the current system time as a [Timestamp]. */
        public fun now(): Timestamp = Timestamp(Clock.System.now())

        /** Decodes a [Timestamp] from BSATN (signed microseconds since the Unix epoch). */
        public fun decode(reader: BsatnReader): Timestamp =
            Timestamp(Instant.fromEpochMicroseconds(reader.readI64()))

        /** Creates a [Timestamp] from microseconds since the Unix epoch. */
        public fun fromEpochMicroseconds(micros: Long): Timestamp =
            Timestamp(Instant.fromEpochMicroseconds(micros))

        /** Creates a [Timestamp] from milliseconds since the Unix epoch. */
        public fun fromMillis(millis: Long): Timestamp =
            Timestamp(Instant.fromEpochMilliseconds(millis))
    }

    /** Encodes this value to BSATN as signed microseconds since the Unix epoch. */
    public fun encode(writer: BsatnWriter) {
        writer.writeI64(instant.toEpochMicroseconds())
    }

    /** Microseconds since Unix epoch */
    public val microsSinceUnixEpoch: Long
        get() = instant.toEpochMicroseconds()

    /** Milliseconds since Unix epoch */
    public val millisSinceUnixEpoch: Long
        get() = instant.toEpochMilliseconds()

    /** Duration since another Timestamp */
    public fun since(other: Timestamp): TimeDuration =
        TimeDuration((microsSinceUnixEpoch - other.microsSinceUnixEpoch).microseconds)

    /** Returns a new [Timestamp] offset forward by [duration]. */
    public operator fun plus(duration: TimeDuration): Timestamp =
        fromEpochMicroseconds(microsSinceUnixEpoch + duration.micros)

    /** Returns a new [Timestamp] offset backward by [duration]. */
    public operator fun minus(duration: TimeDuration): Timestamp =
        fromEpochMicroseconds(microsSinceUnixEpoch - duration.micros)

    /** Returns the duration between this timestamp and [other]. */
    public operator fun minus(other: Timestamp): TimeDuration =
        TimeDuration((microsSinceUnixEpoch - other.microsSinceUnixEpoch).microseconds)

    override operator fun compareTo(other: Timestamp): Int =
        microsSinceUnixEpoch.compareTo(other.microsSinceUnixEpoch)

    /** Returns this timestamp as an ISO 8601 string with microsecond precision. */
    public fun toISOString(): String {
        // floorDiv/mod keep the fraction non-negative for pre-epoch timestamps.
        val micros = microsSinceUnixEpoch
        val seconds = micros.floorDiv(1_000_000L)
        val microFraction = micros.mod(1_000_000L).toInt()
        val base = Instant.fromEpochSeconds(seconds).toString().removeSuffix("Z")
        return "$base.${microFraction.toString().padStart(6, '0')}Z"
    }

    override fun toString(): String = toISOString()
}
} + + @Test + fun `construct from int`() { + assertEquals("42", BigInteger(42).toString()) + assertEquals("-1", BigInteger(-1).toString()) + } + + // ---- Constants ---- + + @Test + fun constants() { + assertEquals("0", BigInteger.ZERO.toString()) + assertEquals("1", BigInteger.ONE.toString()) + assertEquals("2", BigInteger.TWO.toString()) + assertEquals("10", BigInteger.TEN.toString()) + } + + // ---- fromULong ---- + + @Test + fun `from u long zero`() { + assertEquals(BigInteger.ZERO, BigInteger.fromULong(0UL)) + } + + @Test + fun `from u long small`() { + assertEquals(BigInteger(42L), BigInteger.fromULong(42UL)) + } + + @Test + fun `from u long max`() { + // ULong.MAX_VALUE = 2^64 - 1 = 18446744073709551615 + val v = BigInteger.fromULong(ULong.MAX_VALUE) + assertEquals("18446744073709551615", v.toString()) + assertEquals(1, v.signum()) + } + + @Test + fun `from u long high bit set`() { + // 2^63 = 9223372036854775808 (high bit of Long set, but unsigned) + val v = BigInteger.fromULong(9223372036854775808UL) + assertEquals("9223372036854775808", v.toString()) + assertEquals(1, v.signum()) + } + + // ---- parseString decimal ---- + + @Test + fun `parse decimal zero`() { + assertEquals(BigInteger.ZERO, BigInteger.parseString("0")) + } + + @Test + fun `parse decimal positive`() { + assertEquals(BigInteger(42L), BigInteger.parseString("42")) + } + + @Test + fun `parse decimal negative`() { + assertEquals(BigInteger(-42L), BigInteger.parseString("-42")) + } + + @Test + fun `parse decimal large positive`() { + // 2^127 - 1 = I128 max + val s = "170141183460469231731687303715884105727" + val v = BigInteger.parseString(s) + assertEquals(s, v.toString()) + } + + @Test + fun `parse decimal large negative`() { + // -2^127 = I128 min + val s = "-170141183460469231731687303715884105728" + val v = BigInteger.parseString(s) + assertEquals(s, v.toString()) + } + + @Test + fun `parse decimal u256 max`() { + // 2^256 - 1 + val s = 
"115792089237316195423570985008687907853269984665640564039457584007913129639935" + val v = BigInteger.parseString(s) + assertEquals(s, v.toString()) + } + + // ---- parseString hex ---- + + @Test + fun `parse hex zero`() { + assertEquals(BigInteger.ZERO, BigInteger.parseString("0", 16)) + } + + @Test + fun `parse hex small`() { + assertEquals(BigInteger(255L), BigInteger.parseString("ff", 16)) + assertEquals(BigInteger(256L), BigInteger.parseString("100", 16)) + } + + @Test + fun `parse hex upper case`() { + assertEquals(BigInteger(255L), BigInteger.parseString("FF", 16)) + } + + @Test + fun `parse hex negative`() { + assertEquals(BigInteger(-255L), BigInteger.parseString("-ff", 16)) + } + + @Test + fun `parse hex large`() { + // 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF = U128 max + val v = BigInteger.parseString("ffffffffffffffffffffffffffffffff", 16) + assertEquals("340282366920938463463374607431768211455", v.toString()) + } + + // ---- toString hex ---- + + @Test + fun `to string hex zero`() { + assertEquals("0", BigInteger.ZERO.toString(16)) + } + + @Test + fun `to string hex positive`() { + assertEquals("ff", BigInteger(255L).toString(16)) + assertEquals("100", BigInteger(256L).toString(16)) + assertEquals("1", BigInteger(1L).toString(16)) + } + + @Test + fun `to string hex negative`() { + assertEquals("-1", BigInteger(-1L).toString(16)) + assertEquals("-ff", BigInteger(-255L).toString(16)) + } + + @Test + fun `hex round trip`() { + val original = "deadbeef01234567890abcdef" + val v = BigInteger.parseString(original, 16) + assertEquals(original, v.toString(16)) + } + + // ---- Arithmetic: shl ---- + + @Test + fun `shl zero`() { + assertEquals(BigInteger(1L), BigInteger(1L).shl(0)) + } + + @Test + fun `shl by one`() { + assertEquals(BigInteger(2L), BigInteger(1L).shl(1)) + assertEquals(BigInteger(254L), BigInteger(127L).shl(1)) + } + + @Test + fun `shl by eight`() { + assertEquals(BigInteger(256L), BigInteger(1L).shl(8)) + } + + @Test + fun `shl large`() { + // 1 << 
127 = 2^127 + val v = BigInteger.ONE.shl(127) + assertEquals("170141183460469231731687303715884105728", v.toString()) + } + + @Test + fun `shl negative`() { + // -1 << 8 = -256 + assertEquals(BigInteger(-256L), BigInteger(-1L).shl(8)) + // -1 << 1 = -2 + assertEquals(BigInteger(-2L), BigInteger(-1L).shl(1)) + } + + @Test + fun `shl zero value`() { + assertEquals(BigInteger.ZERO, BigInteger.ZERO.shl(100)) + } + + // ---- Arithmetic: add ---- + + @Test + fun `add positive`() { + assertEquals(BigInteger(3L), BigInteger(1L).add(BigInteger(2L))) + } + + @Test + fun `add negative`() { + assertEquals(BigInteger(-3L), BigInteger(-1L).add(BigInteger(-2L))) + } + + @Test + fun `add mixed`() { + assertEquals(BigInteger.ZERO, BigInteger(1L).add(BigInteger(-1L))) + } + + @Test + fun `add large`() { + // (2^127 - 1) + 1 = 2^127 + val max = BigInteger.ONE.shl(127) - BigInteger.ONE + val result = max + BigInteger.ONE + assertEquals(BigInteger.ONE.shl(127), result) + } + + // ---- Arithmetic: subtract ---- + + @Test + fun `subtract positive`() { + assertEquals(BigInteger(-1L), BigInteger(1L) - BigInteger(2L)) + } + + @Test + fun `subtract same`() { + assertEquals(BigInteger.ZERO, BigInteger(42L) - BigInteger(42L)) + } + + // ---- Arithmetic: negate ---- + + @Test + fun `negate positive`() { + assertEquals(BigInteger(-42L), -BigInteger(42L)) + } + + @Test + fun `negate negative`() { + assertEquals(BigInteger(42L), -BigInteger(-42L)) + } + + @Test + fun `negate zero`() { + assertEquals(BigInteger.ZERO, -BigInteger.ZERO) + } + + @Test + fun `negate long min`() { + // -(Long.MIN_VALUE) = Long.MAX_VALUE + 1 = 9223372036854775808 + val v = -BigInteger(Long.MIN_VALUE) + assertEquals("9223372036854775808", v.toString()) + assertEquals(1, v.signum()) + } + + // ---- signum ---- + + @Test + fun `signum values`() { + assertEquals(0, BigInteger.ZERO.signum()) + assertEquals(1, BigInteger.ONE.signum()) + assertEquals(-1, BigInteger(-1L).signum()) + } + + // ---- compareTo ---- + + @Test + fun 
`compare to same value`() { + assertEquals(0, BigInteger(42L).compareTo(BigInteger(42L))) + } + + @Test + fun `compare to positive`() { + assertTrue(BigInteger(1L) < BigInteger(2L)) + assertTrue(BigInteger(2L) > BigInteger(1L)) + } + + @Test + fun `compare to negative`() { + assertTrue(BigInteger(-2L) < BigInteger(-1L)) + } + + @Test + fun `compare to cross sign`() { + assertTrue(BigInteger(-1L) < BigInteger(1L)) + assertTrue(BigInteger(1L) > BigInteger(-1L)) + assertTrue(BigInteger(-1L) < BigInteger.ZERO) + assertTrue(BigInteger.ZERO < BigInteger.ONE) + } + + @Test + fun `compare to large values`() { + val a = BigInteger.ONE.shl(127) + val b = BigInteger.ONE.shl(127) - BigInteger.ONE + assertTrue(a > b) + assertTrue(b < a) + } + + // ---- equals and hashCode ---- + + @Test + fun `equals identical`() { + assertEquals(BigInteger(42L), BigInteger(42L)) + } + + @Test + fun `equals from different paths`() { + // Same value constructed differently should be equal + val a = BigInteger.parseString("255") + val b = BigInteger.parseString("ff", 16) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } + + @Test + fun `not equals different values`() { + assertNotEquals(BigInteger(1L), BigInteger(2L)) + } + + // ---- toByteArray (BE two's complement) ---- + + @Test + fun `to byte array zero`() { + val bytes = BigInteger.ZERO.toByteArray() + assertEquals(1, bytes.size) + assertEquals(0.toByte(), bytes[0]) + } + + @Test + fun `to byte array positive`() { + val bytes = BigInteger(1L).toByteArray() + assertEquals(1, bytes.size) + assertEquals(1.toByte(), bytes[0]) + } + + @Test + fun `to byte array negative`() { + // -1 in BE two's complement = [0xFF] + val bytes = BigInteger(-1L).toByteArray() + assertEquals(1, bytes.size) + assertEquals(0xFF.toByte(), bytes[0]) + } + + @Test + fun `to byte array128`() { + // 128 needs 2 bytes in BE: [0x00, 0x80] + val bytes = BigInteger(128L).toByteArray() + assertEquals(2, bytes.size) + assertEquals(0x00.toByte(), bytes[0]) + 
assertEquals(0x80.toByte(), bytes[1]) + } + + // ---- fromLeBytes / toLeBytesFixedWidth round-trip ---- + + @Test + fun `le bytes round trip16`() { + val values = listOf(BigInteger.ZERO, BigInteger.ONE, BigInteger(-1L), + BigInteger.ONE.shl(127) - BigInteger.ONE, // I128 max + -BigInteger.ONE.shl(127)) // I128 min + for (v in values) { + val le = v.toLeBytesFixedWidth(16) + assertEquals(16, le.size) + val restored = BigInteger.fromLeBytes(le, 0, 16) + assertEquals(v, restored, "LE round-trip failed for $v") + } + } + + @Test + fun `le bytes round trip32`() { + val values = listOf(BigInteger.ZERO, BigInteger.ONE, BigInteger(-1L), + BigInteger.ONE.shl(255) - BigInteger.ONE, // I256 max + -BigInteger.ONE.shl(255)) // I256 min + for (v in values) { + val le = v.toLeBytesFixedWidth(32) + assertEquals(32, le.size) + val restored = BigInteger.fromLeBytes(le, 0, 32) + assertEquals(v, restored, "LE round-trip failed for $v") + } + } + + @Test + fun `from le bytes unsigned max u128`() { + // All 0xFF bytes = U128 max + val le = ByteArray(16) { 0xFF.toByte() } + val v = BigInteger.fromLeBytesUnsigned(le, 0, 16) + assertEquals(1, v.signum()) + assertEquals("340282366920938463463374607431768211455", v.toString()) + } + + // ---- fromByteArray with Sign ---- + + @Test + fun `from byte array positive`() { + // BE magnitude [0xFF] with POSITIVE sign = 255 + val v = BigInteger.fromByteArray(byteArrayOf(0xFF.toByte()), Sign.POSITIVE) + assertEquals(BigInteger(255L), v) + } + + @Test + fun `from byte array negative`() { + val v = BigInteger.fromByteArray(byteArrayOf(0x01), Sign.NEGATIVE) + assertEquals(BigInteger(-1L), v) + } + + @Test + fun `from byte array zero`() { + assertEquals(BigInteger.ZERO, BigInteger.fromByteArray(byteArrayOf(0), Sign.ZERO)) + } + + // ---- fitsInSignedBytes / fitsInUnsignedBytes ---- + + @Test + fun `fits in signed bytes i128`() { + val max = BigInteger.ONE.shl(127) - BigInteger.ONE + val min = -BigInteger.ONE.shl(127) + 
assertTrue(max.fitsInSignedBytes(16)) + assertTrue(min.fitsInSignedBytes(16)) + + val overflow = BigInteger.ONE.shl(127) + assertTrue(!overflow.fitsInSignedBytes(16)) + } + + @Test + fun `fits in unsigned bytes u128`() { + val max = BigInteger.ONE.shl(128) - BigInteger.ONE + assertTrue(max.fitsInUnsignedBytes(16)) + + val overflow = BigInteger.ONE.shl(128) + assertTrue(!overflow.fitsInUnsignedBytes(16)) + + assertTrue(!BigInteger(-1L).fitsInUnsignedBytes(16)) + } + + // ---- Chunk boundary values (128-bit) ---- + + @Test + fun `chunk boundary128`() { + val ONE = BigInteger.ONE + val values = listOf( + ONE.shl(63) - ONE, // 2^63 - 1 + ONE.shl(63), // 2^63 + ONE.shl(64) - ONE, // 2^64 - 1 + ONE.shl(64), // 2^64 + ONE.shl(64) + ONE, // 2^64 + 1 + ) + for (v in values) { + val le = v.toLeBytesFixedWidth(16) + val restored = BigInteger.fromLeBytesUnsigned(le, 0, 16) + assertEquals(v, restored, "Chunk boundary failed for $v") + } + } + + // ---- Chunk boundary values (256-bit) ---- + + @Test + fun `chunk boundary256`() { + val ONE = BigInteger.ONE + val values = listOf( + ONE.shl(63), + ONE.shl(64), + ONE.shl(127), + ONE.shl(128), + ONE.shl(191), + ONE.shl(192), + ONE.shl(255), + ) + for (v in values) { + val le = v.toLeBytesFixedWidth(32) + val restored = BigInteger.fromLeBytesUnsigned(le, 0, 32) + assertEquals(v, restored, "256-bit chunk boundary failed for $v") + } + } + + // ---- Negative LE round-trips (signed) ---- + + @Test + fun `negative le bytes round trip`() { + val ONE = BigInteger.ONE + val values = listOf( + BigInteger(-2), + -ONE.shl(63), + -ONE.shl(64), + -ONE.shl(64) - ONE, + -ONE.shl(127), + ) + for (v in values) { + val le = v.toLeBytesFixedWidth(16) + val restored = BigInteger.fromLeBytes(le, 0, 16) + assertEquals(v, restored, "Negative LE round-trip failed for $v") + } + } + + // ---- Decimal toString round-trip for large values ---- + + @Test + fun `decimal round trip large values`() { + val values = listOf( + 
"170141183460469231731687303715884105727", // I128 max + "-170141183460469231731687303715884105728", // I128 min + "340282366920938463463374607431768211455", // U128 max + "57896044618658097711785492504343953926634992332820282019728792003956564819967", // I256 max + "-57896044618658097711785492504343953926634992332820282019728792003956564819968", // I256 min + "115792089237316195423570985008687907853269984665640564039457584007913129639935", // U256 max + ) + for (s in values) { + val v = BigInteger.parseString(s) + assertEquals(s, v.toString(), "Decimal round-trip failed for $s") + } + } + + // ---- writeLeBytes ---- + + @Test + fun `write le bytes directly`() { + val v = BigInteger(0x0102030405060708L) + val dest = ByteArray(16) + v.writeLeBytes(dest, 0, 16) + assertEquals(0x08.toByte(), dest[0]) + assertEquals(0x07.toByte(), dest[1]) + assertEquals(0x06.toByte(), dest[2]) + assertEquals(0x05.toByte(), dest[3]) + assertEquals(0x04.toByte(), dest[4]) + assertEquals(0x03.toByte(), dest[5]) + assertEquals(0x02.toByte(), dest[6]) + assertEquals(0x01.toByte(), dest[7]) + // Rest should be zero-padded + for (i in 8 until 16) { + assertEquals(0.toByte(), dest[i], "Byte at $i should be 0") + } + } + + @Test + fun `write le bytes negative`() { + val v = BigInteger(-1L) + val dest = ByteArray(16) + v.writeLeBytes(dest, 0, 16) + // -1 in 16 bytes LE = all 0xFF + for (i in 0 until 16) { + assertEquals(0xFF.toByte(), dest[i], "Byte at $i should be 0xFF") + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BsatnRoundTripTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BsatnRoundTripTest.kt new file mode 100644 index 00000000000..bd968aa093b --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BsatnRoundTripTest.kt @@ -0,0 +1,498 @@ +package 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class BsatnRoundTripTest { + private fun roundTrip(write: (BsatnWriter) -> Unit, read: (BsatnReader) -> Any?): Any? { + val writer = BsatnWriter() + write(writer) + val reader = BsatnReader(writer.toByteArray()) + val result = read(reader) + assertEquals(0, reader.remaining, "All bytes should be consumed") + return result + } + + // ---- Bool ---- + + @Test + fun `bool true`() { + val result = roundTrip({ it.writeBool(true) }, { it.readBool() }) + assertTrue(result as Boolean) + } + + @Test + fun `bool false`() { + val result = roundTrip({ it.writeBool(false) }, { it.readBool() }) + assertFalse(result as Boolean) + } + + // ---- I8 / U8 ---- + + @Test + fun `i8 round trip`() { + for (v in listOf(Byte.MIN_VALUE, -1, 0, 1, Byte.MAX_VALUE)) { + val result = roundTrip({ it.writeI8(v) }, { it.readI8() }) + assertEquals(v, result) + } + } + + @Test + fun `u8 round trip`() { + for (v in listOf(0u, 1u, 127u, 255u)) { + val result = roundTrip({ it.writeU8(v.toUByte()) }, { it.readU8() }) + assertEquals(v.toUByte(), result) + } + } + + // ---- I16 / U16 ---- + + @Test + fun `i16 round trip`() { + for (v in listOf(Short.MIN_VALUE, -1, 0, 1, Short.MAX_VALUE)) { + val result = roundTrip({ it.writeI16(v) }, { it.readI16() }) + assertEquals(v, result) + } + } + + @Test + fun `u16 round trip`() { + for (v in listOf(0u, 1u, 32767u, 65535u)) { + val result = roundTrip({ it.writeU16(v.toUShort()) }, { it.readU16() }) + assertEquals(v.toUShort(), result) + } + } + + // ---- I32 / U32 ---- + + @Test + fun `i32 round trip`() { + for (v in listOf(Int.MIN_VALUE, -1, 0, 1, Int.MAX_VALUE)) { + val result = roundTrip({ 
it.writeI32(v) }, { it.readI32() }) + assertEquals(v, result) + } + } + + @Test + fun `u32 round trip`() { + for (v in listOf(0u, 1u, UInt.MAX_VALUE)) { + val result = roundTrip({ it.writeU32(v) }, { it.readU32() }) + assertEquals(v, result) + } + } + + // ---- I64 / U64 ---- + + @Test + fun `i64 round trip`() { + for (v in listOf(Long.MIN_VALUE, -1L, 0L, 1L, Long.MAX_VALUE)) { + val result = roundTrip({ it.writeI64(v) }, { it.readI64() }) + assertEquals(v, result) + } + } + + @Test + fun `u64 round trip`() { + for (v in listOf(0uL, 1uL, ULong.MAX_VALUE)) { + val result = roundTrip({ it.writeU64(v) }, { it.readU64() }) + assertEquals(v, result) + } + } + + // ---- F32 / F64 ---- + + @Test + fun `f32 round trip`() { + for (v in listOf(0.0f, -1.5f, Float.MAX_VALUE, Float.MIN_VALUE, Float.NaN, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY)) { + val writer = BsatnWriter() + writer.writeF32(v) + val reader = BsatnReader(writer.toByteArray()) + val result = reader.readF32() + if (v.isNaN()) { + assertTrue(result.isNaN(), "Expected NaN") + } else { + assertEquals(v, result) + } + } + } + + @Test + fun `f64 round trip`() { + for (v in listOf(0.0, -1.5, Double.MAX_VALUE, Double.MIN_VALUE, Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY)) { + val writer = BsatnWriter() + writer.writeF64(v) + val reader = BsatnReader(writer.toByteArray()) + val result = reader.readF64() + if (v.isNaN()) { + assertTrue(result.isNaN(), "Expected NaN") + } else { + assertEquals(v, result) + } + } + } + + // ---- I128 / U128 ---- + + @Test + fun `i128 round trip`() { + val values = listOf( + BigInteger.ZERO, + BigInteger.ONE, + BigInteger(-1), + BigInteger.parseString("170141183460469231731687303715884105727"), // I128 max (2^127 - 1) + BigInteger.parseString("-170141183460469231731687303715884105728"), // I128 min (-2^127) + ) + for (v in values) { + val result = roundTrip({ it.writeI128(v) }, { it.readI128() }) + assertEquals(v, result, "I128 round-trip failed for $v") + } + } 
+ + @Test + fun `i128 negative edge cases`() { + val ONE = BigInteger.ONE + val values = listOf( + BigInteger(-2), // 0xFF...FE — near -1 + -ONE.shl(63), // -2^63: p0=Long.MIN_VALUE as unsigned, p1=-1 + -ONE.shl(63) + ONE, // -2^63 + 1: p0 high bit set + -ONE.shl(63) - ONE, // -2^63 - 1: borrow from p1 into p0 + -ONE.shl(64), // -2^64: p0=0, p1=-1 — exact chunk boundary + -ONE.shl(64) + ONE, // -2^64 + 1: p0 = ULong.MAX_VALUE, p1 = -2 + -ONE.shl(64) - ONE, // -2^64 - 1: just past chunk boundary + BigInteger.parseString("-9223372036854775808"), // -2^63 as decimal + BigInteger.parseString("-18446744073709551616"), // -2^64 as decimal + ) + for (v in values) { + val result = roundTrip({ it.writeI128(v) }, { it.readI128() }) + assertEquals(v, result, "I128 negative edge case failed for $v") + } + } + + @Test + fun `i128 chunk boundary values`() { + val ONE = BigInteger.ONE + val values = listOf( + ONE.shl(63) - ONE, // 2^63 - 1 = Long.MAX_VALUE in p0 + ONE.shl(63), // 2^63: p0 bit 63 set (unsigned), p1=0 + ONE.shl(64) - ONE, // 2^64 - 1: p0 = all ones (unsigned), p1 = 0 + ONE.shl(64), // 2^64: p0 = 0, p1 = 1 + ONE.shl(64) + ONE, // 2^64 + 1: p0 = 1, p1 = 1 + ) + for (v in values) { + val result = roundTrip({ it.writeI128(v) }, { it.readI128() }) + assertEquals(v, result, "I128 chunk boundary failed for $v") + } + } + + @Test + fun `u128 round trip`() { + val values = listOf( + BigInteger.ZERO, + BigInteger.ONE, + BigInteger.parseString("340282366920938463463374607431768211455"), // U128 max + ) + for (v in values) { + val result = roundTrip({ it.writeU128(v) }, { it.readU128() }) + assertEquals(v, result, "U128 round-trip failed for $v") + } + } + + @Test + fun `u128 chunk boundary values`() { + val ONE = BigInteger.ONE + val values = listOf( + ONE.shl(63) - ONE, // 2^63 - 1: p0 just below Long sign bit + ONE.shl(63), // 2^63: p0 has high bit set (read as negative Long) + ONE.shl(64) - ONE, // 2^64 - 1: p0 all ones, p1 = 0 + ONE.shl(64), // 2^64: p0 = 0, p1 = 1 + 
ONE.shl(127), // 2^127: p1 high bit set (read as negative Long) + ) + for (v in values) { + val result = roundTrip({ it.writeU128(v) }, { it.readU128() }) + assertEquals(v, result, "U128 chunk boundary failed for $v") + } + } + + // ---- I256 / U256 ---- + + @Test + fun `i256 round trip`() { + val values = listOf( + BigInteger.ZERO, + BigInteger.ONE, + BigInteger(-1), + // I256 max: 2^255 - 1 + BigInteger.parseString("57896044618658097711785492504343953926634992332820282019728792003956564819967"), + // I256 min: -2^255 + BigInteger.parseString("-57896044618658097711785492504343953926634992332820282019728792003956564819968"), + ) + for (v in values) { + val result = roundTrip({ it.writeI256(v) }, { it.readI256() }) + assertEquals(v, result, "I256 round-trip failed for $v") + } + } + + @Test + fun `i256 negative edge cases`() { + val ONE = BigInteger.ONE + val values = listOf( + BigInteger(-2), // near -1 + -ONE.shl(63), // -2^63: chunk 0 boundary + -ONE.shl(64), // -2^64: exact chunk 0/1 boundary + -ONE.shl(64) - ONE, // -2^64 - 1: just past first chunk boundary + -ONE.shl(127), // -2^127: chunk 1/2 boundary + -ONE.shl(128), // -2^128: exact chunk 2 boundary + -ONE.shl(128) + ONE, // -2^128 + 1 + -ONE.shl(191), // -2^191: chunk 2/3 boundary + -ONE.shl(192), // -2^192: exact chunk 3 boundary + -ONE.shl(192) - ONE, // -2^192 - 1 + // Large negative with mixed chunk values + BigInteger.parseString("-1000000000000000000000000000000000000000"), + ) + for (v in values) { + val result = roundTrip({ it.writeI256(v) }, { it.readI256() }) + assertEquals(v, result, "I256 negative edge case failed for $v") + } + } + + @Test + fun `i256 chunk boundary values`() { + val ONE = BigInteger.ONE + val values = listOf( + ONE.shl(63), // chunk 0 high bit + ONE.shl(64), // chunk 0→1 boundary + ONE.shl(127), // chunk 1 high bit + ONE.shl(128), // chunk 1→2 boundary + ONE.shl(191), // chunk 2 high bit + ONE.shl(192), // chunk 2→3 boundary + ) + for (v in values) { + val result = 
roundTrip({ it.writeI256(v) }, { it.readI256() }) + assertEquals(v, result, "I256 chunk boundary failed for $v") + } + } + + @Test + fun `u256 round trip`() { + val values = listOf( + BigInteger.ZERO, + BigInteger.ONE, + // U256 max: 2^256 - 1 + BigInteger.parseString("115792089237316195423570985008687907853269984665640564039457584007913129639935"), + ) + for (v in values) { + val result = roundTrip({ it.writeU256(v) }, { it.readU256() }) + assertEquals(v, result, "U256 round-trip failed for $v") + } + } + + @Test + fun `u256 chunk boundary values`() { + val ONE = BigInteger.ONE + val values = listOf( + ONE.shl(63), // chunk 0 high bit (read as negative Long) + ONE.shl(64), // chunk 0→1 boundary + ONE.shl(127), // chunk 1 high bit + ONE.shl(128), // chunk 1→2 boundary + ONE.shl(191), // chunk 2 high bit + ONE.shl(192), // chunk 2→3 boundary + ONE.shl(255), // chunk 3 high bit (read as negative Long) + ) + for (v in values) { + val result = roundTrip({ it.writeU256(v) }, { it.readU256() }) + assertEquals(v, result, "U256 chunk boundary failed for $v") + } + } + + // ---- Overflow detection ---- + + @Test + fun `i128 overflow rejects`() { + val ONE = BigInteger.ONE + val tooLarge = ONE.shl(127) // 2^127 = I128 max + 1 + val tooSmall = -ONE.shl(127) - ONE // -2^127 - 1 + assertFailsWith { + val writer = BsatnWriter() + writer.writeI128(tooLarge) + } + assertFailsWith { + val writer = BsatnWriter() + writer.writeI128(tooSmall) + } + } + + @Test + fun `u128 overflow rejects`() { + val tooLarge = BigInteger.ONE.shl(128) // 2^128 = U128 max + 1 + assertFailsWith { + val writer = BsatnWriter() + writer.writeU128(tooLarge) + } + } + + @Test + fun `u128 negative rejects`() { + assertFailsWith { + val writer = BsatnWriter() + writer.writeU128(BigInteger(-1)) + } + } + + @Test + fun `i256 overflow rejects`() { + val ONE = BigInteger.ONE + val tooLarge = ONE.shl(255) // 2^255 = I256 max + 1 + val tooSmall = -ONE.shl(255) - ONE // -2^255 - 1 + assertFailsWith { + val writer = 
BsatnWriter() + writer.writeI256(tooLarge) + } + assertFailsWith { + val writer = BsatnWriter() + writer.writeI256(tooSmall) + } + } + + @Test + fun `u256 overflow rejects`() { + val tooLarge = BigInteger.ONE.shl(256) // 2^256 = U256 max + 1 + assertFailsWith { + val writer = BsatnWriter() + writer.writeU256(tooLarge) + } + } + + @Test + fun `u256 negative rejects`() { + assertFailsWith { + val writer = BsatnWriter() + writer.writeU256(BigInteger(-1)) + } + } + + // ---- String ---- + + @Test + fun `string empty`() { + val result = roundTrip({ it.writeString("") }, { it.readString() }) + assertEquals("", result) + } + + @Test + fun `string ascii`() { + val result = roundTrip({ it.writeString("hello world") }, { it.readString() }) + assertEquals("hello world", result) + } + + @Test + fun `string multi byte utf8`() { + val s = "\u00E9\u00F1\u00FC\u2603\uD83D\uDE00" // e-acute, n-tilde, u-umlaut, snowman, emoji + val result = roundTrip({ it.writeString(s) }, { it.readString() }) + assertEquals(s, result) + } + + // ---- ByteArray ---- + + @Test + fun `byte array empty`() { + val result = roundTrip({ it.writeByteArray(byteArrayOf()) }, { it.readByteArray() }) + assertTrue((result as ByteArray).isEmpty()) + } + + @Test + fun `byte array non empty`() { + val input = byteArrayOf(0, 1, 127, -128, -1) + val result = roundTrip({ it.writeByteArray(input) }, { it.readByteArray() }) + assertTrue(input.contentEquals(result as ByteArray)) + } + + // ---- ArrayLen ---- + + @Test + fun `array len round trip`() { + for (v in listOf(0, 1, 1000, Int.MAX_VALUE)) { + val result = roundTrip({ it.writeArrayLen(v) }, { it.readArrayLen() }) + assertEquals(v, result) + } + } + + @Test + fun `array len rejects negative`() { + val writer = BsatnWriter() + assertFailsWith { + writer.writeArrayLen(-1) + } + } + + // ---- Overflow checks ---- + + @Test + fun `read string overflow rejects`() { + // Encode a length that exceeds Int.MAX_VALUE (use UInt.MAX_VALUE = 4294967295) + val writer = 
BsatnWriter() + writer.writeU32(UInt.MAX_VALUE) // length prefix > Int.MAX_VALUE + val reader = BsatnReader(writer.toByteArray()) + assertFailsWith { + reader.readString() + } + } + + @Test + fun `read byte array overflow rejects`() { + val writer = BsatnWriter() + writer.writeU32(UInt.MAX_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertFailsWith { + reader.readByteArray() + } + } + + @Test + fun `read array len overflow rejects`() { + val writer = BsatnWriter() + writer.writeU32(UInt.MAX_VALUE) + val reader = BsatnReader(writer.toByteArray()) + assertFailsWith { + reader.readArrayLen() + } + } + + // ---- Reader underflow ---- + + @Test + fun `reader underflow throws`() { + val reader = BsatnReader(byteArrayOf()) + assertFailsWith { + reader.readByte() + } + } + + @Test + fun `reader remaining tracks correctly`() { + val writer = BsatnWriter() + writer.writeI32(42) + writer.writeI32(99) + val reader = BsatnReader(writer.toByteArray()) + assertEquals(8, reader.remaining) + reader.readI32() + assertEquals(4, reader.remaining) + reader.readI32() + assertEquals(0, reader.remaining) + } + + // ---- Writer reset ---- + + @Test + fun `writer reset clears state`() { + val writer = BsatnWriter() + writer.writeI32(42) + assertEquals(4, writer.offset) + writer.reset() + assertEquals(0, writer.offset) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BuilderAndCallbackTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BuilderAndCallbackTest.kt new file mode 100644 index 00000000000..edbb18a213d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/BuilderAndCallbackTest.kt @@ -0,0 +1,434 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNull +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class BuilderAndCallbackTest { + + // --- Builder validation --- + + @Test + fun `builder fails without uri`() = runTest { + assertFailsWith { + DbConnection.Builder() + .withDatabaseName("test") + .build() + } + } + + @Test + fun `builder fails without database name`() = runTest { + assertFailsWith { + DbConnection.Builder() + .withUri("ws://localhost:3000") + .build() + } + } + + // --- Builder ensureMinimumVersion --- + + @Test + fun `builder rejects old cli version`() = runTest { + val oldModule = object : ModuleDescriptor { + override val subscribableTableNames = emptyList() + override val cliVersion = "1.0.0" + override fun registerTables(cache: ClientCache) {} + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {} + } + + assertFailsWith { + DbConnection.Builder() + .withUri("ws://localhost:3000") + .withDatabaseName("test") + .withModule(oldModule) + .build() + } + } + + // --- ensureMinimumVersion edge cases --- + + @Test + fun `builder accepts exact minimum version`() = runTest { + val module = object : ModuleDescriptor { + override val subscribableTableNames = 
emptyList() + override val cliVersion = "2.0.0" + override fun registerTables(cache: ClientCache) {} + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {} + } + + // Should not throw — 2.0.0 is the exact minimum + val conn = buildTestConnection(FakeTransport(), moduleDescriptor = module) + conn.disconnect() + } + + @Test + fun `builder accepts newer version`() = runTest { + val module = object : ModuleDescriptor { + override val subscribableTableNames = emptyList() + override val cliVersion = "3.1.0" + override fun registerTables(cache: ClientCache) {} + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {} + } + + val conn = buildTestConnection(FakeTransport(), moduleDescriptor = module) + conn.disconnect() + } + + @Test + fun `builder accepts pre release suffix`() = runTest { + val module = object : ModuleDescriptor { + override val subscribableTableNames = emptyList() + override val cliVersion = "2.1.0-beta.1" + override fun registerTables(cache: ClientCache) {} + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {} + } + + // Pre-release suffix is stripped; 2.1.0 >= 2.0.0 + val conn = buildTestConnection(FakeTransport(), moduleDescriptor = module) + conn.disconnect() + } + + @Test + fun `builder rejects old minor version`() = runTest { + val module = object : ModuleDescriptor { + override val subscribableTableNames = emptyList() + override val cliVersion = "1.9.9" + override fun 
registerTables(cache: ClientCache) {} + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {} + } + + assertFailsWith { + DbConnection.Builder() + .withUri("ws://localhost:3000") + .withDatabaseName("test") + .withModule(module) + .build() + } + } + + // --- Module descriptor integration --- + + @Test + fun `db connection constructor does not call register tables`() = runTest { + val transport = FakeTransport() + var tablesRegistered = false + + val descriptor = object : ModuleDescriptor { + override val subscribableTableNames = emptyList() + override val cliVersion = "2.0.0" + override fun registerTables(cache: ClientCache) { + tablesRegistered = true + cache.register("sample", createSampleCache()) + } + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {} + } + + // Use the module descriptor through DbConnection — pass it via the helper + val conn = buildTestConnection(transport, moduleDescriptor = descriptor) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // registerTables is the Builder's responsibility, not DbConnection's + assertFalse(tablesRegistered) + assertNull(conn.clientCache.getUntypedTable("sample")) + conn.disconnect() + } + + // --- handleReducerEvent fires from module descriptor --- + + @Test + fun `module descriptor handle reducer event fires`() = runTest { + val transport = FakeTransport() + var reducerEventName: String? 
= null + + val descriptor = object : ModuleDescriptor { + override val subscribableTableNames = emptyList() + override val cliVersion = "2.0.0" + override fun registerTables(cache: ClientCache) {} + override fun createAccessors(conn: DbConnection) = ModuleAccessors( + object : ModuleTables {}, + object : ModuleReducers {}, + object : ModuleProcedures {}, + ) + override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) { + reducerEventName = ctx.reducerName + } + } + + val conn = buildTestConnection(transport, moduleDescriptor = descriptor) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.callReducer("myReducer", byteArrayOf(), "args", callback = null) + advanceUntilIdle() + + val sent = transport.sentMessages.filterIsInstance().last() + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = sent.requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertEquals("myReducer", reducerEventName) + conn.disconnect() + } + + // --- Callback removal --- + + @Test + fun `remove on disconnect prevents callback`() = runTest { + val transport = FakeTransport() + var fired = false + val cb: (DbConnectionView, Throwable?) 
-> Unit = { _, _ -> fired = true } + + val conn = createTestConnection(transport, onDisconnect = cb) + conn.removeOnDisconnect(cb) + conn.connect() + + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + transport.closeFromServer() + advanceUntilIdle() + + assertFalse(fired) + conn.disconnect() + } + + // --- removeOnConnectError --- + + @Test + fun `remove on connect error prevents callback`() = runTest { + val transport = FakeTransport(connectError = RuntimeException("fail")) + var fired = false + val cb: (DbConnectionView, Throwable) -> Unit = { _, _ -> fired = true } + + val conn = createTestConnection(transport, onConnectError = cb) + conn.removeOnConnectError(cb) + + try { + conn.connect() + } catch (_: Exception) { } + advanceUntilIdle() + + assertFalse(fired) + conn.disconnect() + } + + // --- Multiple callbacks --- + + @Test + fun `multiple on connect callbacks all fire`() = runTest { + val transport = FakeTransport() + var count = 0 + val cb: (DbConnectionView, Identity, String) -> Unit = { _, _, _ -> count++ } + val conn = DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = listOf(cb, cb, cb), + onDisconnectCallbacks = emptyList(), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + conn.connect() + + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertEquals(3, count) + conn.disconnect() + } + + // --- User callback exception does not crash receive loop --- + + @Test + fun `user callback exception does not crash connection`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Register a callback that throws + 
cache.onInsert { _, _ -> error("callback explosion") } + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + + // Row should still be inserted despite callback exception + assertEquals(1, cache.count()) + // Connection should still be active + assertTrue(conn.isActive) + conn.disconnect() + } + + // --- Callback exception handling --- + + @Test + fun `on connect callback exception does not prevent other callbacks`() = runTest { + val transport = FakeTransport() + var secondFired = false + val conn = DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = listOf( + { _, _, _ -> error("onConnect explosion") }, + { _, _, _ -> secondFired = true }, + ), + onDisconnectCallbacks = emptyList(), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + conn.connect() + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertTrue(secondFired, "Second onConnect callback should fire despite first throwing") + assertTrue(conn.isActive) + conn.disconnect() + } + + @Test + fun `on delete callback exception does not prevent row removal`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Insert a row first + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = 
handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Register a throwing onDelete callback + cache.onDelete { _, _ -> error("delete callback explosion") } + + // Delete the row via transaction update + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + update = TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf(TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row.encode()), + )) + ) + ), + ) + ) + ) + ) + ) + advanceUntilIdle() + + // Row should still be deleted despite callback exception + assertEquals(0, cache.count()) + assertTrue(conn.isActive) + conn.disconnect() + } + + @Test + fun `reducer callback exception does not crash connection`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val requestId = conn.callReducer( + reducerName = "boom", + encodedArgs = byteArrayOf(), + typedArgs = "args", + callback = { _ -> error("reducer callback explosion") }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertTrue(conn.isActive, "Connection should survive throwing reducer callback") + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CacheOperationsEdgeCaseTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CacheOperationsEdgeCaseTest.kt new file mode 100644 index 00000000000..9cb93793628 --- /dev/null +++ 
b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CacheOperationsEdgeCaseTest.kt @@ -0,0 +1,350 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotEquals +import kotlin.test.assertNull +import kotlin.test.assertSame +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class CacheOperationsEdgeCaseTest { + + // ========================================================================= + // Cache Operations Edge Cases + // ========================================================================= + + @Test + fun `clear fires internal delete listeners for all rows`() { + val cache = createSampleCache() + val deletedRows = mutableListOf() + cache.addInternalDeleteListener { deletedRows.add(it) } + + val row1 = SampleRow(1, "Alice") + val row2 = SampleRow(2, "Bob") + cache.applyInserts(STUB_CTX, buildRowList(row1.encode(), row2.encode())) + + cache.clear() + + assertEquals(0, cache.count()) + assertEquals(2, deletedRows.size) + assertTrue(deletedRows.containsAll(listOf(row1, row2))) + } + + @Test + fun `clear on empty cache is no op`() { + val cache = createSampleCache() + var listenerFired = false + cache.addInternalDeleteListener { listenerFired = true } + + cache.clear() + assertFalse(listenerFired) + } + + @Test + fun `delete nonexistent row is no op`() { + val cache = createSampleCache() + val row = SampleRow(99, "Ghost") + + var deleteFired = false + cache.onDelete { _, _ -> deleteFired = true } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + + assertFalse(deleteFired) + assertEquals(0, cache.count()) + } + + @Test + fun `insert empty row list is no op`() { + val cache = 
createSampleCache() + var insertFired = false + cache.onInsert { _, _ -> insertFired = true } + + val callbacks = cache.applyInserts(STUB_CTX, buildRowList()) + + assertEquals(0, cache.count()) + assertTrue(callbacks.isEmpty()) + assertFalse(insertFired) + } + + @Test + fun `remove callback prevents it from firing`() { + val cache = createSampleCache() + var fired = false + val cb: (EventContext, SampleRow) -> Unit = { _, _ -> fired = true } + + cache.onInsert(cb) + cache.removeOnInsert(cb) + + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(1, "Alice").encode())) + // Invoke any pending callbacks + // No PendingCallbacks should exist for this insert since we removed the callback + + assertFalse(fired) + } + + @Test + fun `internal listeners fired on insert after cas`() { + val cache = createSampleCache() + val internalInserts = mutableListOf() + cache.addInternalInsertListener { internalInserts.add(it) } + + val row = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + assertEquals(listOf(row), internalInserts) + } + + @Test + fun `internal listeners fired on delete after cas`() { + val cache = createSampleCache() + val internalDeletes = mutableListOf() + cache.addInternalDeleteListener { internalDeletes.add(it) } + + val row = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + + assertEquals(listOf(row), internalDeletes) + } + + @Test + fun `internal listeners fired on update for both old and new`() { + val cache = createSampleCache() + val internalInserts = mutableListOf() + val internalDeletes = mutableListOf() + cache.addInternalInsertListener { internalInserts.add(it) } + cache.addInternalDeleteListener { internalDeletes.add(it) } + + val oldRow = SampleRow(1, "Old") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + internalInserts.clear() // Reset from the initial insert + + 
val newRow = SampleRow(1, "New") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + // On update, old row fires delete listener, new row fires insert listener + assertEquals(listOf(oldRow), internalDeletes) + assertEquals(listOf(newRow), internalInserts) + } + + @Test + fun `batch insert multiple rows fires callbacks for each`() { + val cache = createSampleCache() + val inserted = mutableListOf() + cache.onInsert { _, row -> inserted.add(row) } + + val rows = (1..5).map { SampleRow(it, "Row$it") } + val callbacks = cache.applyInserts( + STUB_CTX, + buildRowList(*rows.map { it.encode() }.toTypedArray()) + ) + for (cb in callbacks) cb.invoke() + + assertEquals(5, cache.count()) + assertEquals(rows, inserted) + } + + // ========================================================================= + // ClientCache Registry + // ========================================================================= + + @Test + fun `client cache get table throws for unknown table`() { + val cc = ClientCache() + assertFailsWith { + cc.getTable("nonexistent") + } + } + + @Test + fun `client cache get table or null returns null`() { + val cc = ClientCache() + assertNull(cc.getTableOrNull("nonexistent")) + } + + @Test + fun `client cache get or create table creates once`() { + val cc = ClientCache() + var factoryCalls = 0 + + val cache1 = cc.getOrCreateTable("t") { + factoryCalls++ + createSampleCache() + } + val cache2 = cc.getOrCreateTable("t") { + factoryCalls++ + createSampleCache() + } + + assertEquals(1, factoryCalls) + assertSame(cache1, cache2) + } + + @Test + fun `client cache table names`() { + val cc = ClientCache() + cc.register("alpha", createSampleCache()) + cc.register("beta", createSampleCache()) + + assertEquals(setOf("alpha", "beta"), cc.tableNames()) + } + + @Test + fun `client cache clear clears all 
tables`() { + val cc = ClientCache() + val cacheA = createSampleCache() + val cacheB = createSampleCache() + cc.register("a", cacheA) + cc.register("b", cacheB) + + cacheA.applyInserts(STUB_CTX, buildRowList(SampleRow(1, "X").encode())) + cacheB.applyInserts(STUB_CTX, buildRowList(SampleRow(2, "Y").encode())) + + cc.clear() + + assertEquals(0, cacheA.count()) + assertEquals(0, cacheB.count()) + } + + // ========================================================================= + // Ref Count Edge Cases + // ========================================================================= + + @Test + fun `ref count survives update on multi ref row`() { + val cache = createSampleCache() + val row = SampleRow(1, "Alice") + + // Insert twice — refCount = 2 + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + + // Update the row — should preserve refCount + val updatedRow = SampleRow(1, "Alice Updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(updatedRow.encode()), + deletes = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(1, cache.count()) + assertEquals("Alice Updated", cache.all().single().name) + + // Deleting once should still keep the row (refCount was 2, update preserves it) + val parsedDelete = cache.parseDeletes(buildRowList(updatedRow.encode())) + cache.applyDeletes(STUB_CTX, parsedDelete) + // The refCount was preserved during update, so after one delete it should still be there + assertEquals(1, cache.count()) + } + + @Test + fun `delete with high ref count only decrements`() { + val cache = createSampleCache() + val row = SampleRow(1, "Alice") + + // Insert 3 times — refCount = 3 + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + 
var deleteFired = false + cache.onDelete { _, _ -> deleteFired = true } + + // Delete once — refCount goes to 2 + val parsed1 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed1) + assertEquals(1, cache.count()) + assertFalse(deleteFired) + + // Delete again — refCount goes to 1 + val parsed2 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed2) + assertEquals(1, cache.count()) + assertFalse(deleteFired) + + // Delete final — refCount goes to 0 + val parsed3 = cache.parseDeletes(buildRowList(row.encode())) + val callbacks = cache.applyDeletes(STUB_CTX, parsed3) + for (cb in callbacks) cb.invoke() + assertEquals(0, cache.count()) + assertTrue(deleteFired) + } + + // ========================================================================= + // BsatnRowKey equality and hashCode + // ========================================================================= + + @Test + fun `bsatn row key equality and hash code`() { + val a = BsatnRowKey(byteArrayOf(1, 2, 3)) + val b = BsatnRowKey(byteArrayOf(1, 2, 3)) + val c = BsatnRowKey(byteArrayOf(1, 2, 4)) + + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertNotEquals(a, c) + } + + @Test + fun `bsatn row key works as map key`() { + val map = mutableMapOf() + val key1 = BsatnRowKey(byteArrayOf(10, 20)) + val key2 = BsatnRowKey(byteArrayOf(10, 20)) + val key3 = BsatnRowKey(byteArrayOf(30, 40)) + + map[key1] = "first" + map[key2] = "second" // Same content as key1, should overwrite + map[key3] = "third" + + assertEquals(2, map.size) + assertEquals("second", map[key1]) + assertEquals("third", map[key3]) + } + + // ========================================================================= + // DecodedRow equality + // ========================================================================= + + @Test + fun `decoded row equality`() { + val row1 = DecodedRow(SampleRow(1, "A"), byteArrayOf(1, 2, 3)) + val row2 = DecodedRow(SampleRow(1, "A"), 
byteArrayOf(1, 2, 3)) + val row3 = DecodedRow(SampleRow(1, "A"), byteArrayOf(4, 5, 6)) + + assertEquals(row1, row2) + assertEquals(row1.hashCode(), row2.hashCode()) + assertNotEquals(row1, row3) + } + + // ========================================================================= + // FixedSize hint validation + // ========================================================================= + + @Test + fun `fixed size hint non divisible rows data throws`() { + val cache = createSampleCache() + // 7 bytes of data with FixedSize(4) → 7 % 4 != 0 + val rowList = BsatnRowList( + sizeHint = RowSizeHint.FixedSize(4u), + rowsData = ByteArray(7), + ) + assertFailsWith("Should reject non-divisible FixedSize row data") { + cache.decodeRowList(rowList) + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackOrderingTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackOrderingTest.kt new file mode 100644 index 00000000000..740dcd0c110 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackOrderingTest.kt @@ -0,0 +1,366 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryRows +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QuerySetUpdate +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.SingleTableRows +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TableUpdate +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TableUpdateRows +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TransactionUpdate +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity 
+import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class CallbackOrderingTest { + + // ========================================================================= + // Callback Ordering Guarantees + // ========================================================================= + + @Test + fun `pre apply delete fires before apply delete across tables`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + val cacheA = createSampleCache() + val cacheB = createSampleCache() + conn.clientCache.register("table_a", cacheA) + conn.clientCache.register("table_b", cacheB) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val rowA = SampleRow(1, "A") + val rowB = SampleRow(2, "B") + val handle = conn.subscribe(listOf("SELECT * FROM table_a", "SELECT * FROM table_b")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf( + SingleTableRows("table_a", buildRowList(rowA.encode())), + SingleTableRows("table_b", buildRowList(rowB.encode())), + ) + ), + ) + ) + advanceUntilIdle() + assertEquals(1, cacheA.count()) + assertEquals(1, cacheB.count()) + + // Track ordering: onBeforeDelete should fire for BOTH tables + // BEFORE any onDelete fires + val events = mutableListOf() + cacheA.onBeforeDelete { _, _ -> events.add("beforeDelete_A") } + cacheB.onBeforeDelete { _, _ -> events.add("beforeDelete_B") } + cacheA.onDelete { _, _ -> events.add("delete_A") } + cacheB.onDelete { _, _ -> events.add("delete_B") } + + // Transaction deleting from both tables + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + 
listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "table_a", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(rowA.encode()), + ) + ) + ), + TableUpdate( + "table_b", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(rowB.encode()), + ) + ) + ), + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + // All beforeDeletes must come before any delete + val beforeDeleteIndices = events.indices.filter { events[it].startsWith("beforeDelete") } + val deleteIndices = events.indices.filter { events[it].startsWith("delete_") } + assertTrue(beforeDeleteIndices.isNotEmpty()) + assertTrue(deleteIndices.isNotEmpty()) + assertTrue(beforeDeleteIndices.max() < deleteIndices.min()) + + conn.disconnect() + } + + @Test + fun `update does not fire on before delete for updated row`() { + val cache = createSampleCache() + val oldRow = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + + val beforeDeleteRows = mutableListOf() + cache.onBeforeDelete { _, row -> beforeDeleteRows.add(row) } + + // Update (same key in both inserts and deletes) should NOT fire onBeforeDelete + val newRow = SampleRow(1, "Alice Updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.preApplyUpdate(STUB_CTX, parsed) + cache.applyUpdate(STUB_CTX, parsed) + + assertTrue(beforeDeleteRows.isEmpty(), "onBeforeDelete should NOT fire for updates") + } + + @Test + fun `pure delete fires on before delete`() { + val cache = createSampleCache() + val row = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val beforeDeleteRows = mutableListOf() + cache.onBeforeDelete { _, r -> beforeDeleteRows.add(r) } + + // Pure delete (no corresponding insert) should fire onBeforeDelete + val update = 
TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.preApplyUpdate(STUB_CTX, parsed) + + assertEquals(listOf(row), beforeDeleteRows) + } + + @Test + fun `callback firing order insert update delete`() { + val cache = createSampleCache() + + // Pre-populate + val existingRow = SampleRow(1, "Old") + val toDelete = SampleRow(2, "Delete Me") + cache.applyInserts(STUB_CTX, buildRowList(existingRow.encode(), toDelete.encode())) + + val events = mutableListOf() + cache.onInsert { _, row -> events.add("insert:${row.name}") } + cache.onUpdate { _, old, new -> events.add("update:${old.name}->${new.name}") } + cache.onDelete { _, row -> events.add("delete:${row.name}") } + cache.onBeforeDelete { _, row -> events.add("beforeDelete:${row.name}") } + + // Transaction: update row1, delete row2, insert row3 + val updatedRow = SampleRow(1, "New") + val newRow = SampleRow(3, "Fresh") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(updatedRow.encode(), newRow.encode()), + deletes = buildRowList(existingRow.encode(), toDelete.encode()), + ) + val parsed = cache.parseUpdate(update) + + // Pre-apply phase + cache.preApplyUpdate(STUB_CTX, parsed) + + // Only pure deletes get onBeforeDelete (not updates) + assertEquals(listOf("beforeDelete:Delete Me"), events) + + // Apply phase + events.clear() + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + for (cb in callbacks) cb.invoke() + + // Must contain all events in the correct order: + // updates and inserts fire first (from the insert processing loop), + // then pure deletes (from the remaining-deletes loop). 
+ assertEquals( + listOf("update:Old->New", "insert:Fresh", "delete:Delete Me"), + events, + ) + } + + // ========================================================================= + // Callback Exception Resilience + // ========================================================================= + + @Test + fun `on connect exception does not prevent subsequent messages`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, onConnect = { _, _, _ -> + error("connect callback explosion") + }, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Connection should still work despite callback exception + assertTrue(conn.isActive) + + val handle = conn.subscribe(listOf("SELECT * FROM t")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + // The subscribe was sent and the SubscribeApplied was processed + assertTrue(handle.isActive) + conn.disconnect() + } + + @Test + fun `on before delete exception does not prevent mutation`() { + val cache = createSampleCache() + val row = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + cache.onBeforeDelete { _, _ -> error("boom in beforeDelete") } + + // The preApply phase will throw, but let's verify the apply phase + // still works independently (since the exception is in user code, + // it's caught by runUserCallback in DbConnection context) + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(update) + // preApplyUpdate will throw since we're not wrapped in runUserCallback + // This tests that if it does throw, the cache is still consistent + try { + cache.preApplyUpdate(STUB_CTX, parsed) + } catch (_: Exception) { + // Expected + } + + // applyUpdate should still 
work + cache.applyUpdate(STUB_CTX, parsed) + assertEquals(0, cache.count()) + } + + // ========================================================================= + // EventContext Correctness + // ========================================================================= + + @Test + fun `subscribe applied context type`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var capturedCtx: EventContext? = null + cache.onInsert { ctx, _ -> capturedCtx = ctx } + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + + assertTrue(capturedCtx is EventContext.SubscribeApplied) + conn.disconnect() + } + + @Test + fun `transaction update context type`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + var capturedCtx: EventContext? 
= null + cache.onInsert { ctx, _ -> capturedCtx = ctx } + + transport.sendToClient( + transactionUpdateMsg( + handle.querySetId, + "sample", + inserts = buildRowList(SampleRow(1, "Alice").encode()), + ) + ) + advanceUntilIdle() + + assertTrue(capturedCtx is EventContext.Transaction) + conn.disconnect() + } + + // ========================================================================= + // onDisconnect callback edge cases + // ========================================================================= + + @Test + fun `on disconnect added after build still fires`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Add callback AFTER connection is established + var fired = false + conn.onDisconnect { _, _ -> fired = true } + + conn.disconnect() + advanceUntilIdle() + + assertTrue(fired) + } + + @Test + fun `on connect error added after build still fires`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + // Add callback AFTER connection is established + var fired = false + conn.onConnectError { _, _ -> fired = true } + + // Trigger identity mismatch (which fires onConnectError) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val differentIdentity = Identity(BigInteger.TEN) + transport.sendToClient( + ServerMessage.InitialConnection( + identity = differentIdentity, + connectionId = TEST_CONNECTION_ID, + token = TEST_TOKEN, + ) + ) + advanceUntilIdle() + + assertTrue(fired) + // Connection auto-closes on identity mismatch (no manual disconnect needed) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ClientMessageTest.kt 
b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ClientMessageTest.kt new file mode 100644 index 00000000000..3252d0d741d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ClientMessageTest.kt @@ -0,0 +1,169 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ClientMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QuerySetId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.UnsubscribeFlags +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class ClientMessageTest { + + // ---- Subscribe (tag 0) ---- + + @Test + fun `subscribe encodes correctly`() { + val msg = ClientMessage.Subscribe( + requestId = 42u, + querySetId = QuerySetId(7u), + queryStrings = listOf("SELECT * FROM Players", "SELECT * FROM Items"), + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(0, reader.readSumTag().toInt(), "tag") + assertEquals(42u, reader.readU32(), "requestId") + assertEquals(7u, reader.readU32(), "querySetId") + assertEquals(2, reader.readArrayLen(), "query count") + assertEquals("SELECT * FROM Players", reader.readString()) + assertEquals("SELECT * FROM Items", reader.readString()) + assertEquals(0, reader.remaining) + } + + @Test + fun `subscribe empty queries`() { + val msg = ClientMessage.Subscribe( + requestId = 0u, + querySetId = QuerySetId(0u), + queryStrings = emptyList(), + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(0, reader.readSumTag().toInt()) + assertEquals(0u, reader.readU32()) + assertEquals(0u, reader.readU32()) + assertEquals(0, reader.readArrayLen()) + assertEquals(0, 
reader.remaining) + } + + // ---- Unsubscribe (tag 1) ---- + + @Test + fun `unsubscribe default flags`() { + val msg = ClientMessage.Unsubscribe( + requestId = 10u, + querySetId = QuerySetId(5u), + flags = UnsubscribeFlags.Default, + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(1, reader.readSumTag().toInt(), "tag") + assertEquals(10u, reader.readU32(), "requestId") + assertEquals(5u, reader.readU32(), "querySetId") + assertEquals(0, reader.readSumTag().toInt(), "flags = Default") + assertEquals(0, reader.remaining) + } + + @Test + fun `unsubscribe send dropped rows flags`() { + val msg = ClientMessage.Unsubscribe( + requestId = 10u, + querySetId = QuerySetId(5u), + flags = UnsubscribeFlags.SendDroppedRows, + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(1, reader.readSumTag().toInt()) + assertEquals(10u, reader.readU32()) + assertEquals(5u, reader.readU32()) + assertEquals(1, reader.readSumTag().toInt(), "flags = SendDroppedRows") + assertEquals(0, reader.remaining) + } + + // ---- OneOffQuery (tag 2) ---- + + @Test + fun `one off query encodes correctly`() { + val msg = ClientMessage.OneOffQuery( + requestId = 99u, + queryString = "SELECT * FROM Players WHERE id = 1", + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(2, reader.readSumTag().toInt(), "tag") + assertEquals(99u, reader.readU32(), "requestId") + assertEquals("SELECT * FROM Players WHERE id = 1", reader.readString()) + assertEquals(0, reader.remaining) + } + + // ---- CallReducer (tag 3) ---- + + @Test + fun `call reducer encodes correctly`() { + val args = byteArrayOf(1, 2, 3, 4) + val msg = ClientMessage.CallReducer( + requestId = 7u, + flags = 0u, + reducer = "add_player", + args = args, + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(3, reader.readSumTag().toInt(), "tag") + 
assertEquals(7u, reader.readU32(), "requestId") + assertEquals(0u.toUByte(), reader.readU8(), "flags") + assertEquals("add_player", reader.readString(), "reducer") + assertTrue(args.contentEquals(reader.readByteArray()), "args") + assertEquals(0, reader.remaining) + } + + @Test + fun `call reducer equality`() { + val msg1 = ClientMessage.CallReducer(1u, 0u, "test", byteArrayOf(1, 2, 3)) + val msg2 = ClientMessage.CallReducer(1u, 0u, "test", byteArrayOf(1, 2, 3)) + val msg3 = ClientMessage.CallReducer(1u, 0u, "test", byteArrayOf(4, 5, 6)) + + assertEquals(msg1, msg2) + assertEquals(msg1.hashCode(), msg2.hashCode()) + assertTrue(msg1 != msg3) + } + + // ---- CallProcedure (tag 4) ---- + + @Test + fun `call procedure encodes correctly`() { + val args = byteArrayOf(10, 20) + val msg = ClientMessage.CallProcedure( + requestId = 3u, + flags = 1u, + procedure = "get_player_stats", + args = args, + ) + val bytes = ClientMessage.encodeToBytes(msg) + val reader = BsatnReader(bytes) + + assertEquals(4, reader.readSumTag().toInt(), "tag") + assertEquals(3u, reader.readU32(), "requestId") + assertEquals(1u.toUByte(), reader.readU8(), "flags") + assertEquals("get_player_stats", reader.readString(), "procedure") + assertTrue(args.contentEquals(reader.readByteArray()), "args") + assertEquals(0, reader.remaining) + } + + @Test + fun `call procedure equality`() { + val msg1 = ClientMessage.CallProcedure(1u, 0u, "proc", byteArrayOf(1)) + val msg2 = ClientMessage.CallProcedure(1u, 0u, "proc", byteArrayOf(1)) + val msg3 = ClientMessage.CallProcedure(1u, 0u, "proc", byteArrayOf(2)) + + assertEquals(msg1, msg2) + assertEquals(msg1.hashCode(), msg2.hashCode()) + assertTrue(msg1 != msg3) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConnectionLifecycleTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConnectionLifecycleTest.kt new file mode 100644 
index 00000000000..6f4747c0101 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConnectionLifecycleTest.kt @@ -0,0 +1,358 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class ConnectionLifecycleTest { + + // --- Connection lifecycle --- + + @Test + fun `on connect fires after initial connection`() = runTest { + val transport = FakeTransport() + var connectIdentity: Identity? = null + var connectToken: String? = null + + val conn = buildTestConnection(transport, onConnect = { _, id, tok -> + connectIdentity = id + connectToken = tok + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertEquals(TEST_IDENTITY, connectIdentity) + assertEquals(TEST_TOKEN, connectToken) + conn.disconnect() + } + + @Test + fun `identity and token set after connect`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + + assertNull(conn.identity) + assertNull(conn.token) + + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertEquals(TEST_IDENTITY, conn.identity) + assertEquals(TEST_TOKEN, conn.token) + assertEquals(TEST_CONNECTION_ID, conn.connectionId) + conn.disconnect() + } + + @Test + fun `on disconnect fires on server close`() = runTest { + val transport = FakeTransport() + var disconnected = false + var disconnectError: Throwable? 
= null + + val conn = buildTestConnection(transport, onDisconnect = { _, err -> + disconnected = true + disconnectError = err + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + transport.closeFromServer() + advanceUntilIdle() + + assertTrue(disconnected) + assertNull(disconnectError) + conn.disconnect() + } + + // --- onConnectError --- + + @Test + fun `on connect error fires when transport fails`() = runTest { + val error = RuntimeException("connection refused") + val transport = FakeTransport(connectError = error) + var capturedError: Throwable? = null + + val conn = createTestConnection(transport, onConnectError = { _, err -> + capturedError = err + }) + conn.connect() + + assertEquals(error, capturedError) + assertFalse(conn.isActive) + } + + // --- Identity mismatch --- + + @Test + fun `identity mismatch fires on connect error and disconnects`() = runTest { + val transport = FakeTransport() + var errorMsg: String? = null + var disconnectReason: Throwable? 
= null + var disconnected = false + val conn = buildTestConnection( + transport, + onConnectError = { _, err -> errorMsg = err.message }, + onDisconnect = { _, reason -> + disconnected = true + disconnectReason = reason + }, + ) + + // First InitialConnection sets identity + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + assertEquals(TEST_IDENTITY, conn.identity) + + // Second InitialConnection with different identity triggers error and disconnect + val differentIdentity = Identity(BigInteger.TEN) + transport.sendToClient( + ServerMessage.InitialConnection( + identity = differentIdentity, + connectionId = TEST_CONNECTION_ID, + token = TEST_TOKEN, + ) + ) + advanceUntilIdle() + + // onConnectError fired + assertNotNull(errorMsg) + assertTrue(errorMsg!!.contains("unexpected identity")) + // Identity should NOT have changed + assertEquals(TEST_IDENTITY, conn.identity) + // Connection should have transitioned to CLOSED (not left in CONNECTED) + assertTrue(disconnected, "onDisconnect should have fired") + assertNotNull(disconnectReason, "disconnect reason should be the identity mismatch error") + assertTrue(disconnectReason!!.message!!.contains("unexpected identity")) + } + + // --- close() --- + + @Test + fun `close fires on disconnect`() = runTest { + val transport = FakeTransport() + var disconnected = false + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> + disconnected = true + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + assertTrue(disconnected) + } + + // --- disconnect() states --- + + @Test + fun `disconnect when already disconnected is no op`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + // Second disconnect should not throw + conn.disconnect() + } + + // --- close() from 
never-connected state --- + + @Test + fun `close from never connected state`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport) + // close() on a freshly created connection that was never connected should not throw + conn.disconnect() + } + + // --- use {} block --- + + @Test + fun `use block disconnects on normal return`() = runTest { + val transport = FakeTransport() + var disconnected = false + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> disconnected = true }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.use { /* no-op */ } + advanceUntilIdle() + + assertTrue(disconnected) + assertFalse(conn.isActive) + } + + @Test + fun `use block disconnects on exception`() = runTest { + val transport = FakeTransport() + var disconnected = false + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> disconnected = true }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertFailsWith { + conn.use { throw IllegalStateException("boom") } + } + advanceUntilIdle() + + assertTrue(disconnected) + assertFalse(conn.isActive) + } + + @Test + fun `use block returns value`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val result = conn.use { 42 } + + assertEquals(42, result) + } + + @Test + fun `use block disconnects on cancellation`() = runTest { + val transport = FakeTransport() + var disconnected = false + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> disconnected = true }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val job = launch { + conn.use { kotlinx.coroutines.awaitCancellation() } + } + advanceUntilIdle() + + job.cancel() + advanceUntilIdle() + + assertTrue(disconnected) + } + + // --- Token not overwritten if already set --- + + @Test + fun `token not overwritten on 
second initial connection`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + + // First connection sets token + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + assertEquals(TEST_TOKEN, conn.token) + + // Second InitialConnection with same identity but different token — token stays + transport.sendToClient( + ServerMessage.InitialConnection( + identity = TEST_IDENTITY, + connectionId = TEST_CONNECTION_ID, + token = "new-token", + ) + ) + advanceUntilIdle() + + assertEquals(TEST_TOKEN, conn.token) + conn.disconnect() + } + + // --- sendMessage after close --- + + @Test + fun `subscribe after close does not crash`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + // Calling subscribe on a closed connection is a graceful no-op + // (logs warning, does not throw) + conn.subscribe(listOf("SELECT * FROM player")) + } + + // --- Disconnect race conditions --- + + @Test + fun `disconnect during server close does not double fire callbacks`() = runTest { + val transport = FakeTransport() + var disconnectCount = 0 + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> + disconnectCount++ + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Close from server side and call disconnect concurrently + transport.closeFromServer() + conn.disconnect() + advanceUntilIdle() + + assertEquals(1, disconnectCount, "onDisconnect should fire exactly once") + } + + @Test + fun `disconnect passes reason to callbacks`() = runTest { + val transport = FakeTransport() + var receivedError: Throwable? 
= null + val conn = buildTestConnection(transport, onDisconnect = { _, err -> + receivedError = err + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val reason = RuntimeException("forced disconnect") + conn.disconnect(reason) + advanceUntilIdle() + + assertEquals(reason, receivedError) + } + + // --- SubscriptionError with null requestId triggers disconnect --- + + @Test + fun `subscription error with null request id disconnects`() = runTest { + val transport = FakeTransport() + var disconnected = false + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> + disconnected = true + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var errorMsg: String? = null + val handle = conn.subscribe( + queries = listOf("SELECT * FROM player"), + onError = listOf { _, err -> errorMsg = (err as SubscriptionError.ServerError).message }, + ) + + transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = null, + querySetId = handle.querySetId, + error = "fatal subscription error", + ) + ) + advanceUntilIdle() + + assertEquals("fatal subscription error", errorMsg) + assertTrue(handle.isEnded) + assertTrue(disconnected) + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConnectionStateTransitionTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConnectionStateTransitionTest.kt new file mode 100644 index 00000000000..4402a025963 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConnectionStateTransitionTest.kt @@ -0,0 +1,326 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.test.advanceUntilIdle +import 
kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class ConnectionStateTransitionTest { + + // ========================================================================= + // Connection State Transitions + // ========================================================================= + + @Test + fun `connection state progression`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + // Initial state — not active + assertFalse(conn.isActive) + + // After connect() — active + conn.connect() + assertTrue(conn.isActive) + + // After disconnect() — not active + conn.disconnect() + advanceUntilIdle() + assertFalse(conn.isActive) + } + + @Test + fun `connect after disconnect throws`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + conn.connect() + conn.disconnect() + advanceUntilIdle() + + // CLOSED is terminal — cannot reconnect + assertFailsWith { + conn.connect() + } + } + + @Test + fun `double connect throws`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + conn.connect() + + // Already CONNECTED — second connect should fail + assertFailsWith { + conn.connect() + } + conn.disconnect() + } + + @Test + fun `connect failure renders connection inactive`() = runTest { + val error = RuntimeException("connection refused") + val transport = FakeTransport(connectError = error) + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + conn.connect() + + 
assertFalse(conn.isActive) + // Cannot reconnect after failure (state is CLOSED) + assertFailsWith { conn.connect() } + } + + @Test + fun `server close renders connection inactive`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertTrue(conn.isActive) + transport.closeFromServer() + advanceUntilIdle() + + assertFalse(conn.isActive) + } + + @Test + fun `disconnect from never connected is no op`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + // Should not throw + conn.disconnect() + assertFalse(conn.isActive) + } + + @Test + fun `disconnect after connect renders inactive`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + conn.connect() + assertTrue(conn.isActive) + + conn.disconnect() + advanceUntilIdle() + + assertFalse(conn.isActive) + } + + // ========================================================================= + // Post-Disconnect Operations — sendMessage returns false, caller cleans up + // ========================================================================= + + @Test + fun `call reducer after disconnect cleans up tracking`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + // sendMessage returns false — callback and tracker must be cleaned up + conn.callReducer("add", byteArrayOf(), "args") + assertEquals(0, conn.stats.reducerRequestTracker.requestsAwaitingResponse, + "Reducer tracker must be cleaned up when send fails") + } + + @Test + fun `call procedure 
after disconnect cleans up tracking`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + conn.callProcedure("proc", byteArrayOf()) + assertEquals(0, conn.stats.procedureRequestTracker.requestsAwaitingResponse, + "Procedure tracker must be cleaned up when send fails") + } + + @Test + fun `one off query after disconnect cleans up tracking`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + conn.oneOffQuery("SELECT 1") {} + assertEquals(0, conn.stats.oneOffRequestTracker.requestsAwaitingResponse, + "OneOffQuery tracker must be cleaned up when send fails") + } + + @Test + fun `subscribe after disconnect cleans up tracking`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + assertEquals(0, conn.stats.subscriptionRequestTracker.requestsAwaitingResponse, + "Subscription tracker must be cleaned up when send fails") + assertTrue(handle.isEnded, "Handle must be marked ended when send fails") + } + + // ========================================================================= + // Disconnect reason propagation + // ========================================================================= + + @Test + fun `disconnect with reason passes reason to callbacks`() = runTest { + val transport = FakeTransport() + var receivedReason: Throwable? 
= null + val conn = buildTestConnection(transport, onDisconnect = { _, err -> + receivedReason = err + }, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val reason = RuntimeException("intentional shutdown") + conn.disconnect(reason) + advanceUntilIdle() + + assertEquals(reason, receivedReason) + } + + @Test + fun `disconnect without reason passes null`() = runTest { + val transport = FakeTransport() + var receivedReason: Throwable? = Throwable("sentinel") + val conn = buildTestConnection(transport, onDisconnect = { _, err -> + receivedReason = err + }, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + assertNull(receivedReason) + } + + // ========================================================================= + // SubscriptionBuilder — subscribe(query) does NOT merge with addQuery() + // ========================================================================= + + @Test + fun `subscribe with query does not merge accumulated add query calls`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.subscriptionBuilder() + .addQuery("SELECT * FROM users") + .subscribe("SELECT * FROM messages") + advanceUntilIdle() + + val subMsg = transport.sentMessages.filterIsInstance().last() + assertEquals( + listOf("SELECT * FROM messages"), + subMsg.queryStrings, + "subscribe(query) must use only the passed query, ignoring addQuery() calls" + ) + conn.disconnect() + } + + @Test + fun `subscribe with list does not merge accumulated add query calls`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + 
transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.subscriptionBuilder() + .addQuery("SELECT * FROM users") + .subscribe(listOf("SELECT * FROM messages", "SELECT * FROM notes")) + advanceUntilIdle() + + val subMsg = transport.sentMessages.filterIsInstance().last() + assertEquals( + listOf("SELECT * FROM messages", "SELECT * FROM notes"), + subMsg.queryStrings, + "subscribe(List) must use only the passed queries, ignoring addQuery() calls" + ) + conn.disconnect() + } + + // ========================================================================= + // Empty Subscription Queries + // ========================================================================= + + @Test + fun `subscribe with empty query list sends message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(emptyList()) + advanceUntilIdle() + + val subMsg = transport.sentMessages.filterIsInstance().lastOrNull() + assertNotNull(subMsg) + assertTrue(subMsg.queryStrings.isEmpty()) + assertEquals(emptyList(), handle.queries) + conn.disconnect() + } + + // ========================================================================= + // SubscriptionHandle.queries stores original query strings + // ========================================================================= + + @Test + fun `subscription handle stores original queries`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val queries = listOf("SELECT * FROM users", "SELECT * FROM messages") + val handle = conn.subscribe(queries) + + assertEquals(queries, handle.queries) + conn.disconnect() + } + + // 
========================================================================= + // Connect then immediate disconnect — state must end as Closed + // ========================================================================= + + @Test + fun `connect then immediate disconnect ends as closed`() = runTest { + val transport = FakeTransport() + val conn = createTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + conn.connect() + assertTrue(conn.isActive) + + // Disconnect immediately without waiting for server handshake + conn.disconnect() + advanceUntilIdle() + + assertFalse(conn.isActive, "State must be Closed after disconnect, not stuck in Connected") + + // Must not be reconnectable — Closed is terminal + assertFailsWith { conn.connect() } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/DisconnectScenarioTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/DisconnectScenarioTest.kt new file mode 100644 index 00000000000..36c49dc1cdf --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/DisconnectScenarioTest.kt @@ -0,0 +1,493 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.Transport +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import 
kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class DisconnectScenarioTest { + + // ========================================================================= + // Disconnect-During-Transaction Scenarios + // ========================================================================= + + @Test + fun `disconnect during pending one off query fails callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var callbackResult: OneOffQueryResult? = null + conn.oneOffQuery("SELECT * FROM sample") { result -> + callbackResult = result + } + advanceUntilIdle() + + // Disconnect before the server responds + conn.disconnect() + advanceUntilIdle() + + // Callback should have been invoked with an error + val result = assertNotNull(callbackResult) + assertIs>(result) + } + + @Test + fun `disconnect during pending suspend one off query throws`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var queryResult: OneOffQueryResult? = null + var queryError: Throwable? = null + launch { + try { + queryResult = conn.oneOffQuery("SELECT * FROM sample") + } catch (e: Throwable) { + queryError = e + } + } + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + // The query must not hang silently — it must resolve on disconnect. + // failPendingOperations delivers an error result via the callback. 
+ if (queryResult != null) { + assertIs>(queryResult, "Disconnect should produce SdkResult.Failure") + } else { + assertNotNull(queryError, "Suspended oneOffQuery must resolve on disconnect — got neither result nor error") + } + conn.disconnect() + } + + @Test + fun `server close during multiple pending operations`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Create multiple pending operations + val subHandle = conn.subscribe(listOf("SELECT * FROM t")) + var reducerFired = false + conn.callReducer("add", byteArrayOf(), "args", callback = { _ -> reducerFired = true }) + var queryResult: OneOffQueryResult? = null + conn.oneOffQuery("SELECT 1") { queryResult = it } + advanceUntilIdle() + + // Server closes connection + transport.closeFromServer() + advanceUntilIdle() + + // All pending operations should be cleaned up + assertTrue(subHandle.isEnded) + assertFalse(reducerFired) // Reducer callback never fires — it was discarded + val qResult = assertNotNull(queryResult) // One-off query callback fires with error + assertIs>(qResult) + } + + @Test + fun `transaction update during disconnect does not crash`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + + // Send a transaction update and immediately close + 
transport.sendToClient( + transactionUpdateMsg( + handle.querySetId, + "sample", + inserts = buildRowList(SampleRow(2, "Bob").encode()), + ) + ) + transport.closeFromServer() + advanceUntilIdle() + + // Should not crash — the transaction update may or may not have been processed + assertFalse(conn.isActive) + } + + // ========================================================================= + // Concurrent / racing disconnect + // ========================================================================= + + @Test + fun `disconnect while connecting does not crash`() = runTest { + // Use a transport that suspends forever in connect() + val suspendingTransport = object : Transport { + override suspend fun connect() { + kotlinx.coroutines.awaitCancellation() + } + override suspend fun send(message: ClientMessage) {} + override fun incoming(): kotlinx.coroutines.flow.Flow = + kotlinx.coroutines.flow.emptyFlow() + override suspend fun disconnect() {} + } + + val conn = DbConnection( + transport = suspendingTransport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = emptyList(), + onDisconnectCallbacks = emptyList(), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + + // Start connecting in a background job — it will suspend in transport.connect() + val connectJob = launch { conn.connect() } + advanceUntilIdle() + + // Disconnect while connect() is still suspended + conn.disconnect() + advanceUntilIdle() + + assertFalse(conn.isActive) + connectJob.cancel() + } + + @Test + fun `multiple sequential disconnects fire callback only once`() = runTest { + val transport = FakeTransport() + var disconnectCount = 0 + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> + disconnectCount++ + }, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + 
transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + assertTrue(conn.isActive) + + // Three rapid sequential disconnects + conn.disconnect() + conn.disconnect() + conn.disconnect() + advanceUntilIdle() + + assertEquals(1, disconnectCount) + } + + @Test + fun `disconnect during subscribe applied processing`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + // Queue a SubscribeApplied then immediately disconnect + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(SampleRow(1, "Alice").encode()))) + ), + ) + ) + conn.disconnect() + advanceUntilIdle() + + // Connection must be closed; cache state depends on timing but must be consistent + assertFalse(conn.isActive) + } + + @Test + fun `disconnect clears client cache completely`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf( + SingleTableRows( + "sample", + buildRowList( + SampleRow(1, "Alice").encode(), + SampleRow(2, "Bob").encode(), + ) + ) + ) + ), + ) + ) + advanceUntilIdle() + assertEquals(2, cache.count()) + + conn.disconnect() + advanceUntilIdle() + + // disconnect() must clear the cache + assertEquals(0, 
cache.count()) + } + + @Test + fun `disconnect clears indexes consistently with cache`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + + val uniqueIndex = UniqueIndex(cache) { it.id } + val btreeIndex = BTreeIndex(cache) { it.name } + + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf( + SingleTableRows( + "sample", + buildRowList( + SampleRow(1, "Alice").encode(), + SampleRow(2, "Bob").encode(), + ) + ) + ) + ), + ) + ) + advanceUntilIdle() + assertEquals(2, cache.count()) + assertNotNull(uniqueIndex.find(1)) + assertNotNull(uniqueIndex.find(2)) + assertEquals(1, btreeIndex.filter("Alice").size) + + // Send a transaction inserting a new row, then immediately disconnect. + // Before the fix, the receive loop could complete the CAS (adding the row + // and firing internal index listeners) but then disconnect() would clear + // _rows before the indexes were also cleared — leaving stale index entries. + transport.sendToClient( + transactionUpdateMsg( + handle.querySetId, + "sample", + inserts = buildRowList(SampleRow(3, "Charlie").encode()), + ) + ) + conn.disconnect() + advanceUntilIdle() + + // After disconnect, cache and indexes must be consistent: + // either both have the row or neither does. 
+ assertEquals(0, cache.count(), "Cache should be cleared after disconnect") + assertNull(uniqueIndex.find(1), "UniqueIndex should be cleared after disconnect") + assertNull(uniqueIndex.find(2), "UniqueIndex should be cleared after disconnect") + assertNull(uniqueIndex.find(3), "UniqueIndex should not have stale entries after disconnect") + assertTrue(btreeIndex.filter("Alice").isEmpty(), "BTreeIndex should be cleared after disconnect") + assertTrue(btreeIndex.filter("Bob").isEmpty(), "BTreeIndex should be cleared after disconnect") + assertTrue(btreeIndex.filter("Charlie").isEmpty(), "BTreeIndex should not have stale entries after disconnect") + } + + @Test + fun `server close followed by client disconnect does not double fail pending`() = runTest { + val transport = FakeTransport() + var disconnectCount = 0 + val conn = buildTestConnection(transport, onDisconnect = { _, _ -> + disconnectCount++ + }, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Fire a reducer call so there's a pending callback + conn.callReducer("test", byteArrayOf(1), "args") + advanceUntilIdle() + + // Server closes, then client also calls disconnect + transport.closeFromServer() + conn.disconnect() + advanceUntilIdle() + + // Callback fires at most once + assertEquals(1, disconnectCount) + assertFalse(conn.isActive) + } + + // ========================================================================= + // Reconnection (new connection after old one is closed) + // ========================================================================= + + @Test + fun `fresh connection works after previous disconnect`() = runTest { + val transport1 = FakeTransport() + val conn1 = buildTestConnection(transport1, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport1.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertTrue(conn1.isActive) + assertEquals(TEST_IDENTITY, conn1.identity) + + 
conn1.disconnect() + advanceUntilIdle() + assertFalse(conn1.isActive) + + // Build a completely new connection (the "reconnect by rebuilding" pattern) + val transport2 = FakeTransport() + val secondIdentity = Identity(BigInteger.TEN) + val secondConnectionId = ConnectionId(BigInteger(20)) + var conn2ConnectFired = false + val conn2 = buildTestConnection(transport2, onConnect = { _, _, _ -> + conn2ConnectFired = true + }, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport2.sendToClient( + ServerMessage.InitialConnection( + identity = secondIdentity, + connectionId = secondConnectionId, + token = "new-token", + ) + ) + advanceUntilIdle() + + assertTrue(conn2.isActive) + assertTrue(conn2ConnectFired) + assertEquals(secondIdentity, conn2.identity) + + // Old connection must remain closed + assertFalse(conn1.isActive) + conn2.disconnect() + } + + @Test + fun `fresh connection cache is independent from old`() = runTest { + val transport1 = FakeTransport() + val conn1 = buildTestConnection(transport1, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache1 = createSampleCache() + conn1.clientCache.register("sample", cache1) + transport1.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Insert a row via first connection + val handle1 = conn1.subscribe(listOf("SELECT * FROM sample")) + transport1.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(SampleRow(1, "Alice").encode()))) + ), + ) + ) + advanceUntilIdle() + assertEquals(1, cache1.count()) + + conn1.disconnect() + advanceUntilIdle() + + // Second connection has its own empty cache + val transport2 = FakeTransport() + val conn2 = buildTestConnection(transport2, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache2 = createSampleCache() + conn2.clientCache.register("sample", cache2) + transport2.sendToClient(initialConnectionMsg()) + 
advanceUntilIdle() + + assertEquals(0, cache2.count()) + conn2.disconnect() + } + + // ========================================================================= + // sendMessage after disconnect — graceful failure (no crash) + // ========================================================================= + + @Test + fun `send message after disconnect does not crash`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + assertTrue(conn.isActive) + + conn.disconnect() + advanceUntilIdle() + assertFalse(conn.isActive) + + // Attempting to send after disconnect logs a warning and returns — no throw + conn.callReducer("add", byteArrayOf(), "args") + // No exception means success + } + + @Test + fun `send message on closed channel does not crash`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + assertTrue(conn.isActive) + + // Server closes the connection + transport.closeFromServer() + advanceUntilIdle() + + // Any send attempt after server close logs a warning — no throw + conn.oneOffQuery("SELECT 1") {} + } + + @Test + fun `reducer callback does not fire on failed send`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.disconnect() + advanceUntilIdle() + + // callReducer returns without throwing — the callback is registered but + // will never fire since the message was not sent and the connection is closed. 
+ var callbackFired = false + conn.callReducer("add", byteArrayOf(), "args", callback = { _ -> + callbackFired = true + }) + advanceUntilIdle() + + assertFalse(callbackFired) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/FakeTransport.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/FakeTransport.kt new file mode 100644 index 00000000000..0c89231dd64 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/FakeTransport.kt @@ -0,0 +1,63 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ClientMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.Transport +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.persistentListOf +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.consumeAsFlow + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class, kotlinx.coroutines.DelicateCoroutinesApi::class) +internal class FakeTransport( + private val connectError: Throwable? 
= null, +) : Transport { + private var _incoming = Channel(Channel.UNLIMITED) + private val _sent = atomic(persistentListOf()) + private val _sendError = atomic(null) + private var _connected = false + + override suspend fun connect() { + connectError?.let { throw it } + // Recreate channel on reconnect (closed channels can't be reused) + if (_incoming.isClosedForSend) { + _incoming = Channel(Channel.UNLIMITED) + } + _connected = true + } + + override suspend fun send(message: ClientMessage) { + _sendError.value?.let { throw it } + _sent.update { it.add(message) } + } + + override fun incoming(): Flow = _incoming.consumeAsFlow() + + override suspend fun disconnect() { + _connected = false + _incoming.close() + } + + val sentMessages: List get() = _sent.value + + suspend fun sendToClient(message: ServerMessage) { + _incoming.send(message) + } + + /** Close the incoming channel normally (flow completes, onDisconnect fires with null error). */ + fun closeFromServer() { + _incoming.close() + } + + /** Close the incoming channel with an error (flow throws, onDisconnect fires with the error). */ + fun closeWithError(cause: Throwable) { + _incoming.close(cause) + } + + /** When set, subsequent [send] calls throw this error (simulates send-path failure). */ + var sendError: Throwable? 
+ get() = _sendError.value + set(value) { _sendError.value = value } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IndexTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IndexTest.kt new file mode 100644 index 00000000000..39d5aaf3b25 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IndexTest.kt @@ -0,0 +1,157 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class IndexTest { + + // ---- UniqueIndex ---- + + @Test + fun `unique index find returns correct row`() { + val cache = createSampleCache() + val alice = SampleRow(1, "alice") + val bob = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(alice.encode(), bob.encode())) + + val index = UniqueIndex(cache) { it.id } + assertEquals(alice, index.find(1)) + assertEquals(bob, index.find(2)) + assertNull(index.find(99)) + } + + @Test + fun `unique index tracks inserts`() { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + assertNull(index.find(1)) + + val alice = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(alice.encode())) + + assertEquals(alice, index.find(1)) + } + + @Test + fun `unique index tracks deletes`() { + val cache = createSampleCache() + val alice = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(alice.encode())) + + val index = UniqueIndex(cache) { it.id } + assertEquals(alice, index.find(1)) + + val parsed = cache.parseDeletes(buildRowList(alice.encode())) + cache.applyDeletes(STUB_CTX, parsed) + + assertNull(index.find(1)) + } + + // ---- BTreeIndex ---- + + @Test + fun `btree index filter returns all matching`() { + val cache = createSampleCache() + val alice = SampleRow(1, "alice") + val bob = 
SampleRow(2, "bob") + val charlie = SampleRow(3, "alice") + cache.applyInserts(STUB_CTX, buildRowList(alice.encode(), bob.encode(), charlie.encode())) + + val index = BTreeIndex(cache) { it.name } + val alices = index.filter("alice").sortedBy { it.id } + assertEquals(listOf(alice, charlie), alices) + assertEquals(setOf(bob), index.filter("bob")) + assertEquals(emptySet(), index.filter("nobody")) + } + + @Test + fun `btree index handles duplicate keys`() { + val cache = createSampleCache() + val r1 = SampleRow(1, "same") + val r2 = SampleRow(2, "same") + val r3 = SampleRow(3, "same") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode(), r3.encode())) + + val index = BTreeIndex(cache) { it.name } + assertEquals(3, index.filter("same").size) + } + + @Test + fun `btree index tracks inserts`() { + val cache = createSampleCache() + val index = BTreeIndex(cache) { it.name } + + assertEquals(emptySet(), index.filter("alice")) + + val alice = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(alice.encode())) + + assertEquals(setOf(alice), index.filter("alice")) + } + + @Test + fun `btree index removes empty key on delete`() { + val cache = createSampleCache() + val alice = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(alice.encode())) + + val index = BTreeIndex(cache) { it.name } + assertEquals(setOf(alice), index.filter("alice")) + + val parsed = cache.parseDeletes(buildRowList(alice.encode())) + cache.applyDeletes(STUB_CTX, parsed) + + assertEquals(emptySet(), index.filter("alice")) + } + + @Test + fun `btree index partial delete keeps remaining rows`() { + val cache = createSampleCache() + val r1 = SampleRow(1, "group") + val r2 = SampleRow(2, "group") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode())) + + val index = BTreeIndex(cache) { it.name } + assertEquals(2, index.filter("group").size) + + val parsed = cache.parseDeletes(buildRowList(r1.encode())) + cache.applyDeletes(STUB_CTX, parsed) + + val 
remaining = index.filter("group") + assertEquals(1, remaining.size) + assertEquals(r2, remaining.single()) + } + + // ---- Null key handling ---- + + @Test + fun `unique index handles null keys`() { + val cache = createSampleCache() + val nullKeyRow = SampleRow(0, "null-key") + val normalRow = SampleRow(1, "normal") + cache.applyInserts(STUB_CTX, buildRowList(nullKeyRow.encode(), normalRow.encode())) + + // Key extractor returns null for id == 0 + val index = UniqueIndex(cache) { if (it.id == 0) null else it.id } + assertEquals(nullKeyRow, index.find(null)) + assertEquals(normalRow, index.find(1)) + assertNull(index.find(99)) + } + + @Test + fun `btree index handles null keys`() { + val cache = createSampleCache() + val r1 = SampleRow(0, "a") + val r2 = SampleRow(1, "b") + val r3 = SampleRow(2, "c") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode(), r3.encode())) + + // Key extractor returns null for id == 0 + val index = BTreeIndex(cache) { if (it.id == 0) null else it.id } + assertEquals(setOf(r1), index.filter(null)) + assertEquals(setOf(r2), index.filter(1)) + assertEquals(emptySet(), index.filter(99)) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IntegrationTestHelpers.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IntegrationTestHelpers.kt new file mode 100644 index 00000000000..7163842a305 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IntegrationTestHelpers.kt @@ -0,0 +1,109 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.Transport +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.TestScope + +val TEST_IDENTITY = Identity(BigInteger.ONE) +val TEST_CONNECTION_ID = ConnectionId(BigInteger.TWO) +const val TEST_TOKEN = "test-token-abc" + +fun initialConnectionMsg() = ServerMessage.InitialConnection( + identity = TEST_IDENTITY, + connectionId = TEST_CONNECTION_ID, + token = TEST_TOKEN, +) + +internal suspend fun TestScope.buildTestConnection( + transport: FakeTransport, + onConnect: ((DbConnectionView, Identity, String) -> Unit)? = null, + onDisconnect: ((DbConnectionView, Throwable?) -> Unit)? = null, + onConnectError: ((DbConnectionView, Throwable) -> Unit)? = null, + moduleDescriptor: ModuleDescriptor? = null, + callbackDispatcher: kotlinx.coroutines.CoroutineDispatcher? = null, + exceptionHandler: CoroutineExceptionHandler? = null, +): DbConnection { + val conn = createTestConnection(transport, onConnect, onDisconnect, onConnectError, moduleDescriptor, callbackDispatcher, exceptionHandler) + conn.connect() + return conn +} + +internal fun TestScope.createTestConnection( + transport: FakeTransport, + onConnect: ((DbConnectionView, Identity, String) -> Unit)? = null, + onDisconnect: ((DbConnectionView, Throwable?) -> Unit)? = null, + onConnectError: ((DbConnectionView, Throwable) -> Unit)? = null, + moduleDescriptor: ModuleDescriptor? = null, + callbackDispatcher: kotlinx.coroutines.CoroutineDispatcher? = null, + exceptionHandler: CoroutineExceptionHandler? 
= null, +): DbConnection { + val baseContext = SupervisorJob() + StandardTestDispatcher(testScheduler) + val context = if (exceptionHandler != null) baseContext + exceptionHandler else baseContext + return DbConnection( + transport = transport, + scope = CoroutineScope(context), + onConnectCallbacks = listOfNotNull(onConnect), + onDisconnectCallbacks = listOfNotNull(onDisconnect), + onConnectErrorCallbacks = listOfNotNull(onConnectError), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = moduleDescriptor, + callbackDispatcher = callbackDispatcher, + ) +} + +internal fun TestScope.createConnectionWithTransport( + transport: Transport, + onDisconnect: ((DbConnectionView, Throwable?) -> Unit)? = null, +): DbConnection { + return DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = emptyList(), + onDisconnectCallbacks = listOfNotNull(onDisconnect), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) +} + +fun emptyQueryRows(): QueryRows = QueryRows(emptyList()) + +fun transactionUpdateMsg( + querySetId: QuerySetId, + tableName: String, + inserts: BsatnRowList = buildRowList(), + deletes: BsatnRowList = buildRowList(), +) = ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + querySetId, + listOf( + TableUpdate( + tableName, + listOf(TableUpdateRows.PersistentTable(inserts, deletes)) + ) + ) + ) + ) + ) +) + +fun encodeInitialConnectionBytes(): ByteArray { + val writer = BsatnWriter() + writer.writeSumTag(0u) // InitialConnection tag + TEST_IDENTITY.encode(writer) + TEST_CONNECTION_ID.encode(writer) + writer.writeString(TEST_TOKEN) + return writer.toByteArray() +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/LoggerTest.kt 
b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/LoggerTest.kt new file mode 100644 index 00000000000..cf658a88c31 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/LoggerTest.kt @@ -0,0 +1,133 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.test.AfterTest + +class LoggerTest { + private val originalLevel = Logger.level + private val originalHandler = Logger.handler + + @AfterTest + fun restoreLogger() { + Logger.level = originalLevel + Logger.handler = originalHandler + } + + // ---- Redaction ---- + + @Test + fun `redacts token equals`() { + val messages = mutableListOf() + Logger.level = LogLevel.INFO + Logger.handler = LogHandler { _, msg -> messages.add(msg) } + + Logger.info { "Connecting with token=secret123 to server" } + + assertEquals(1, messages.size) + assertTrue(messages[0].contains("[REDACTED]"), "Token value should be redacted") + assertFalse(messages[0].contains("secret123"), "Original secret should not appear") + } + + @Test + fun `redacts token colon`() { + val messages = mutableListOf() + Logger.level = LogLevel.INFO + Logger.handler = LogHandler { _, msg -> messages.add(msg) } + + Logger.info { "token: mySecretValue" } + + assertTrue(messages[0].contains("[REDACTED]")) + assertFalse(messages[0].contains("mySecretValue")) + } + + @Test + fun `redacts case insensitive`() { + val messages = mutableListOf() + Logger.level = LogLevel.INFO + Logger.handler = LogHandler { _, msg -> messages.add(msg) } + + Logger.info { "TOKEN=abc123" } + Logger.info { "Token=def456" } + Logger.info { "PASSWORD=hunter2" } + + assertEquals(3, messages.size) + for (msg in messages) { + assertTrue(msg.contains("[REDACTED]"), "Should redact: $msg") + } + } + + @Test + fun `redacts multiple 
patterns in one message`() { + val messages = mutableListOf() + Logger.level = LogLevel.INFO + Logger.handler = LogHandler { _, msg -> messages.add(msg) } + + Logger.info { "token=abc password=xyz" } + + assertEquals(1, messages.size) + assertFalse(messages[0].contains("abc"), "First secret should be redacted") + assertFalse(messages[0].contains("xyz"), "Second secret should be redacted") + } + + @Test + fun `non sensitive passes through`() { + val messages = mutableListOf() + Logger.level = LogLevel.INFO + Logger.handler = LogHandler { _, msg -> messages.add(msg) } + + Logger.info { "Connected to database on port 3000" } + + assertEquals(1, messages.size) + assertEquals("Connected to database on port 3000", messages[0]) + } + + // ---- Log level filtering ---- + + @Test + fun `should log ordinal logic`() { + // EXCEPTION(0) should log at any level + assertTrue(LogLevel.EXCEPTION.shouldLog(LogLevel.EXCEPTION)) + assertTrue(LogLevel.EXCEPTION.shouldLog(LogLevel.TRACE)) + + // TRACE(5) should only log at TRACE level + assertTrue(LogLevel.TRACE.shouldLog(LogLevel.TRACE)) + assertFalse(LogLevel.TRACE.shouldLog(LogLevel.INFO)) + assertFalse(LogLevel.TRACE.shouldLog(LogLevel.EXCEPTION)) + } + + @Test + fun `log level filters suppresses lower priority`() { + val messages = mutableListOf() + Logger.level = LogLevel.WARN + Logger.handler = LogHandler { lvl, _ -> messages.add(lvl) } + + Logger.error { "error" } // should log (ERROR < WARN in ordinal) + Logger.warn { "warn" } // should log (WARN == WARN) + Logger.info { "info" } // should NOT log (INFO > WARN in ordinal) + Logger.debug { "debug" } // should NOT log + Logger.trace { "trace" } // should NOT log + + assertEquals(listOf(LogLevel.ERROR, LogLevel.WARN), messages) + } + + // ---- Custom handler ---- + + @Test + fun `custom handler receives correct level and message`() { + var capturedLevel: LogLevel? = null + var capturedMessage: String? 
= null + Logger.level = LogLevel.DEBUG + Logger.handler = LogHandler { lvl, msg -> + capturedLevel = lvl + capturedMessage = msg + } + + Logger.debug { "test message" } + + assertEquals(LogLevel.DEBUG, capturedLevel) + assertEquals("test message", capturedMessage) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProcedureAndQueryIntegrationTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProcedureAndQueryIntegrationTest.kt new file mode 100644 index 00000000000..f44476fa2a8 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProcedureAndQueryIntegrationTest.kt @@ -0,0 +1,295 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class ProcedureAndQueryIntegrationTest { + + // --- Procedures --- + + @Test + fun `call procedure sends client message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.callProcedure("my_proc", byteArrayOf(42)) + advanceUntilIdle() + + val procMsg = transport.sentMessages.filterIsInstance().firstOrNull() + assertNotNull(procMsg) + 
assertEquals("my_proc", procMsg.procedure) + assertTrue(procMsg.args.contentEquals(byteArrayOf(42))) + conn.disconnect() + } + + @Test + fun `procedure result fires callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var receivedStatus: ProcedureStatus? = null + val requestId = conn.callProcedure( + procedureName = "my_proc", + args = byteArrayOf(), + callback = { _, msg -> receivedStatus = msg.status }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ProcedureResultMsg( + status = ProcedureStatus.Returned(byteArrayOf(1, 2, 3)), + timestamp = Timestamp.UNIX_EPOCH, + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + requestId = requestId, + ) + ) + advanceUntilIdle() + + assertTrue(receivedStatus is ProcedureStatus.Returned) + conn.disconnect() + } + + @Test + fun `procedure result internal error fires callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var receivedStatus: ProcedureStatus? 
= null + val requestId = conn.callProcedure( + procedureName = "bad_proc", + args = byteArrayOf(), + callback = { _, msg -> receivedStatus = msg.status }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ProcedureResultMsg( + status = ProcedureStatus.InternalError("proc failed"), + timestamp = Timestamp.UNIX_EPOCH, + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + requestId = requestId, + ) + ) + advanceUntilIdle() + + assertTrue(receivedStatus is ProcedureStatus.InternalError) + assertEquals("proc failed", (receivedStatus as ProcedureStatus.InternalError).message) + conn.disconnect() + } + + // --- One-off queries --- + + @Test + fun `one off query callback receives result`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var result: OneOffQueryResult? = null + val requestId = conn.oneOffQuery("SELECT * FROM sample") { msg -> + result = msg + } + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.OneOffQueryResult( + requestId = requestId, + result = QueryResult.Ok(emptyQueryRows()), + ) + ) + advanceUntilIdle() + + val capturedResult = result + assertNotNull(capturedResult) + assertTrue(capturedResult is SdkResult.Success) + conn.disconnect() + } + + @Test + fun `one off query suspend returns result`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Retrieve the requestId that will be assigned by inspecting sentMessages + val beforeCount = transport.sentMessages.size + // Launch the suspend query in a separate coroutine since it suspends + var queryResult: OneOffQueryResult? 
= null + launch { + queryResult = conn.oneOffQuery("SELECT * FROM sample") + } + advanceUntilIdle() + + // Find the OneOffQuery message + val queryMsg = transport.sentMessages.drop(beforeCount) + .filterIsInstance().firstOrNull() + assertNotNull(queryMsg) + + transport.sendToClient( + ServerMessage.OneOffQueryResult( + requestId = queryMsg.requestId, + result = QueryResult.Ok(emptyQueryRows()), + ) + ) + advanceUntilIdle() + + val capturedQueryResult = queryResult + assertNotNull(capturedQueryResult) + assertTrue(capturedQueryResult is SdkResult.Success) + conn.disconnect() + } + + // --- One-off query error --- + + @Test + fun `one off query callback receives error`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var result: OneOffQueryResult? = null + val requestId = conn.oneOffQuery("SELECT * FROM bad") { msg -> + result = msg + } + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.OneOffQueryResult( + requestId = requestId, + result = QueryResult.Err("syntax error"), + ) + ) + advanceUntilIdle() + + val capturedResult = result + assertNotNull(capturedResult) + assertTrue(capturedResult is SdkResult.Failure) + val queryError = capturedResult.error + assertTrue(queryError is QueryError.ServerError) + assertEquals("syntax error", queryError.message) + conn.disconnect() + } + + // --- oneOffQuery cancellation --- + + @Test + fun `one off query suspend cancellation cleans up callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val job = launch { + conn.oneOffQuery("SELECT * FROM sample") // will suspend forever + } + advanceUntilIdle() + + // Cancel the coroutine — should clean up the callback + job.cancel() + advanceUntilIdle() + + // Now send a result for that requestId — should not crash + val queryMsg = 
transport.sentMessages.filterIsInstance().lastOrNull() + assertNotNull(queryMsg) + transport.sendToClient( + ServerMessage.OneOffQueryResult( + requestId = queryMsg.requestId, + result = QueryResult.Ok(emptyQueryRows()), + ) + ) + advanceUntilIdle() + + assertTrue(conn.isActive) + conn.disconnect() + } + + // --- oneOffQuery suspend with finite timeout --- + + @Test + fun `one off query suspend times out when no response`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + assertFailsWith { + conn.oneOffQuery("SELECT * FROM sample", timeout = 1.milliseconds) + } + + conn.disconnect() + } + + // --- callProcedure without callback (fire-and-forget) --- + + @Test + fun `call procedure without callback sends message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.callProcedure("myProc", byteArrayOf(), callback = null) + advanceUntilIdle() + + val sent = transport.sentMessages.filterIsInstance() + assertEquals(1, sent.size) + assertEquals("myProc", sent[0].procedure) + + // Sending a result for it should not crash (no callback registered) + transport.sendToClient( + ServerMessage.ProcedureResultMsg( + requestId = sent[0].requestId, + timestamp = Timestamp.UNIX_EPOCH, + status = ProcedureStatus.Returned(byteArrayOf()), + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + ) + ) + advanceUntilIdle() + + assertTrue(conn.isActive) + conn.disconnect() + } + + // --- Procedure result before identity is set --- + + @Test + fun `procedure result before identity set is ignored`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + // Do NOT send InitialConnection — identity stays null + + transport.sendToClient( + ServerMessage.ProcedureResultMsg( + requestId = 1u, + timestamp = 
Timestamp.UNIX_EPOCH, + status = ProcedureStatus.Returned(byteArrayOf()), + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + ) + ) + advanceUntilIdle() + + assertTrue(conn.isActive) + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProtocolDecodeTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProtocolDecodeTest.kt new file mode 100644 index 00000000000..b9cf369b1bd --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProtocolDecodeTest.kt @@ -0,0 +1,326 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.BsatnRowList +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.DecompressedPayload +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ProcedureStatus +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryRows +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ReducerOutcome +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.RowSizeHint +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.SingleTableRows +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TableUpdateRows +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs +import kotlin.test.assertTrue + +class ProtocolDecodeTest { + + // ---- RowSizeHint ---- + + @Test + fun `row size hint fixed size decode`() { + val writer = BsatnWriter() + writer.writeSumTag(0u) // tag = FixedSize + writer.writeU16(4u) // 4 
bytes per row + + val hint = RowSizeHint.decode(BsatnReader(writer.toByteArray())) + val fixed = assertIs(hint) + assertEquals(4u.toUShort(), fixed.size) + } + + @Test + fun `row size hint row offsets decode`() { + val writer = BsatnWriter() + writer.writeSumTag(1u) // tag = RowOffsets + writer.writeArrayLen(3) + writer.writeU64(0uL) + writer.writeU64(10uL) + writer.writeU64(25uL) + + val hint = RowSizeHint.decode(BsatnReader(writer.toByteArray())) + val offsets = assertIs(hint) + assertEquals(listOf(0uL, 10uL, 25uL), offsets.offsets) + } + + @Test + fun `row size hint unknown tag throws`() { + val writer = BsatnWriter() + writer.writeSumTag(99u) // invalid tag + + assertFailsWith { + RowSizeHint.decode(BsatnReader(writer.toByteArray())) + } + } + + // ---- BsatnRowList ---- + + @Test + fun `bsatn row list decode with fixed size`() { + val writer = BsatnWriter() + // RowSizeHint::FixedSize(4) + writer.writeSumTag(0u) + writer.writeU16(4u) + // Rows data: U32 length prefix + raw bytes + val rowData = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8) // 2 rows of 4 bytes + writer.writeU32(rowData.size.toUInt()) + writer.writeRawBytes(rowData) + + val rowList = BsatnRowList.decode(BsatnReader(writer.toByteArray())) + assertIs(rowList.sizeHint) + assertEquals(8, rowList.rowsSize) + } + + @Test + fun `bsatn row list decode with row offsets`() { + val writer = BsatnWriter() + // RowSizeHint::RowOffsets([0, 5]) + writer.writeSumTag(1u) + writer.writeArrayLen(2) + writer.writeU64(0uL) + writer.writeU64(5uL) + // Rows data + val rowData = byteArrayOf(10, 20, 30, 40, 50, 60, 70, 80, 90) + writer.writeU32(rowData.size.toUInt()) + writer.writeRawBytes(rowData) + + val rowList = BsatnRowList.decode(BsatnReader(writer.toByteArray())) + assertIs(rowList.sizeHint) + assertEquals(9, rowList.rowsSize) + } + + @Test + fun `bsatn row list decode overflow length throws`() { + val writer = BsatnWriter() + // RowSizeHint::FixedSize(4) + writer.writeSumTag(0u) + writer.writeU16(4u) + // Length that 
overflows Int: 0x80000000 (2,147,483,648) + writer.writeU32(0x8000_0000u) + // No actual row data — the check should fire before reading + + assertFailsWith { + BsatnRowList.decode(BsatnReader(writer.toByteArray())) + } + } + + // ---- SingleTableRows ---- + + @Test + fun `single table rows decode`() { + val writer = BsatnWriter() + writer.writeString("Players") + // BsatnRowList: FixedSize(4), 4 bytes of data + writer.writeSumTag(0u) + writer.writeU16(4u) + writer.writeU32(4u) + writer.writeRawBytes(byteArrayOf(0, 0, 0, 42)) + + val rows = SingleTableRows.decode(BsatnReader(writer.toByteArray())) + assertEquals("Players", rows.table) + assertEquals(4, rows.rows.rowsSize) + } + + // ---- QueryRows ---- + + @Test + fun `query rows decode empty`() { + val writer = BsatnWriter() + writer.writeArrayLen(0) + + val qr = QueryRows.decode(BsatnReader(writer.toByteArray())) + assertTrue(qr.tables.isEmpty()) + } + + @Test + fun `query rows decode with tables`() { + val writer = BsatnWriter() + writer.writeArrayLen(2) + // Table 1 + writer.writeString("Players") + writer.writeSumTag(0u); writer.writeU16(4u) // FixedSize(4) + writer.writeU32(0u) // 0 bytes of row data + // Table 2 + writer.writeString("Items") + writer.writeSumTag(0u); writer.writeU16(8u) // FixedSize(8) + writer.writeU32(0u) // 0 bytes of row data + + val qr = QueryRows.decode(BsatnReader(writer.toByteArray())) + assertEquals(2, qr.tables.size) + assertEquals("Players", qr.tables[0].table) + assertEquals("Items", qr.tables[1].table) + } + + // ---- TableUpdateRows ---- + + @Test + fun `table update rows persistent table decode`() { + val writer = BsatnWriter() + writer.writeSumTag(0u) // tag = PersistentTable + // inserts: BsatnRowList + writer.writeSumTag(0u); writer.writeU16(4u) // FixedSize(4) + writer.writeU32(4u) + writer.writeRawBytes(byteArrayOf(1, 0, 0, 0)) // one I32 row + // deletes: BsatnRowList + writer.writeSumTag(0u); writer.writeU16(4u) // FixedSize(4) + writer.writeU32(0u) // no deletes + + 
val update = TableUpdateRows.decode(BsatnReader(writer.toByteArray())) + val pt = assertIs(update) + assertEquals(4, pt.inserts.rowsSize) + assertEquals(0, pt.deletes.rowsSize) + } + + @Test + fun `table update rows event table decode`() { + val writer = BsatnWriter() + writer.writeSumTag(1u) // tag = EventTable + // events: BsatnRowList + writer.writeSumTag(0u); writer.writeU16(4u) // FixedSize(4) + writer.writeU32(8u) + writer.writeRawBytes(byteArrayOf(1, 0, 0, 0, 2, 0, 0, 0)) + + val update = TableUpdateRows.decode(BsatnReader(writer.toByteArray())) + val et = assertIs(update) + assertEquals(8, et.events.rowsSize) + } + + @Test + fun `table update rows unknown tag throws`() { + val writer = BsatnWriter() + writer.writeSumTag(99u) + + assertFailsWith { + TableUpdateRows.decode(BsatnReader(writer.toByteArray())) + } + } + + // ---- ReducerOutcome ---- + + @Test + fun `reducer outcome ok decode`() { + val writer = BsatnWriter() + writer.writeSumTag(0u) // tag = Ok + writer.writeByteArray(byteArrayOf(42)) // retValue + writer.writeArrayLen(0) // empty TransactionUpdate + + val outcome = ReducerOutcome.decode(BsatnReader(writer.toByteArray())) + val ok = assertIs(outcome) + assertTrue(ok.retValue.contentEquals(byteArrayOf(42))) + assertTrue(ok.transactionUpdate.querySets.isEmpty()) + } + + @Test + fun `reducer outcome ok empty decode`() { + val writer = BsatnWriter() + writer.writeSumTag(1u) // tag = OkEmpty + + val outcome = ReducerOutcome.decode(BsatnReader(writer.toByteArray())) + assertIs(outcome) + } + + @Test + fun `reducer outcome err decode`() { + val writer = BsatnWriter() + writer.writeSumTag(2u) // tag = Err + writer.writeByteArray(byteArrayOf(0xDE.toByte())) + + val outcome = ReducerOutcome.decode(BsatnReader(writer.toByteArray())) + val err = assertIs(outcome) + assertTrue(err.error.contentEquals(byteArrayOf(0xDE.toByte()))) + } + + @Test + fun `reducer outcome internal error decode`() { + val writer = BsatnWriter() + writer.writeSumTag(3u) // tag = 
InternalError + writer.writeString("panic in reducer") + + val outcome = ReducerOutcome.decode(BsatnReader(writer.toByteArray())) + val err = assertIs(outcome) + assertEquals("panic in reducer", err.message) + } + + @Test + fun `reducer outcome unknown tag throws`() { + val writer = BsatnWriter() + writer.writeSumTag(99u) + + assertFailsWith { + ReducerOutcome.decode(BsatnReader(writer.toByteArray())) + } + } + + // ---- ProcedureStatus ---- + + @Test + fun `procedure status returned decode`() { + val writer = BsatnWriter() + writer.writeSumTag(0u) // tag = Returned + writer.writeByteArray(byteArrayOf(1, 2, 3)) + + val status = ProcedureStatus.decode(BsatnReader(writer.toByteArray())) + val returned = assertIs(status) + assertTrue(returned.value.contentEquals(byteArrayOf(1, 2, 3))) + } + + @Test + fun `procedure status internal error decode`() { + val writer = BsatnWriter() + writer.writeSumTag(1u) // tag = InternalError + writer.writeString("procedure crashed") + + val status = ProcedureStatus.decode(BsatnReader(writer.toByteArray())) + val err = assertIs(status) + assertEquals("procedure crashed", err.message) + } + + @Test + fun `procedure status unknown tag throws`() { + val writer = BsatnWriter() + writer.writeSumTag(99u) + + assertFailsWith { + ProcedureStatus.decode(BsatnReader(writer.toByteArray())) + } + } + + // ---- DecompressedPayload offset validation ---- + + @Test + fun `decompressed payload valid offset`() { + val data = byteArrayOf(1, 2, 3, 4) + val payload = DecompressedPayload(data, 1) + assertEquals(3, payload.size) + } + + @Test + fun `decompressed payload zero offset`() { + val data = byteArrayOf(1, 2, 3) + val payload = DecompressedPayload(data, 0) + assertEquals(3, payload.size) + } + + @Test + fun `decompressed payload offset at end`() { + val data = byteArrayOf(1, 2) + val payload = DecompressedPayload(data, 2) + assertEquals(0, payload.size) + } + + @Test + fun `decompressed payload negative offset rejects`() { + assertFailsWith { + 
DecompressedPayload(byteArrayOf(1, 2), -1) + } + } + + @Test + fun `decompressed payload offset beyond size rejects`() { + assertFailsWith { + DecompressedPayload(byteArrayOf(1, 2), 3) + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProtocolRoundTripTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProtocolRoundTripTest.kt new file mode 100644 index 00000000000..a1eb5c92173 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ProtocolRoundTripTest.kt @@ -0,0 +1,531 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.time.Duration + +/** Encode→decode round-trip tests for ClientMessage and ServerMessage. 
*/ +class ProtocolRoundTripTest { + + // ---- ClientMessage round-trips (encode → decode → assertEquals) ---- + + @Test + fun `client message subscribe round trip`() { + val original = ClientMessage.Subscribe( + requestId = 42u, + querySetId = QuerySetId(7u), + queryStrings = listOf("SELECT * FROM player", "SELECT * FROM item WHERE owner = 1"), + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message subscribe empty queries round trip`() { + val original = ClientMessage.Subscribe( + requestId = 0u, + querySetId = QuerySetId(0u), + queryStrings = emptyList(), + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message unsubscribe default round trip`() { + val original = ClientMessage.Unsubscribe( + requestId = 10u, + querySetId = QuerySetId(3u), + flags = UnsubscribeFlags.Default, + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message unsubscribe send dropped rows round trip`() { + val original = ClientMessage.Unsubscribe( + requestId = 10u, + querySetId = QuerySetId(3u), + flags = UnsubscribeFlags.SendDroppedRows, + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message one off query round trip`() { + val original = ClientMessage.OneOffQuery( + requestId = 99u, + queryString = "SELECT count(*) FROM users", + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message call reducer round trip`() { + val original = ClientMessage.CallReducer( + requestId = 5u, + flags = 0u, + reducer = "add_player", + args = byteArrayOf(1, 2, 3, 4, 5), + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message call reducer empty args round trip`() { + val original = ClientMessage.CallReducer( + requestId 
= 0u, + flags = 1u, + reducer = "noop", + args = byteArrayOf(), + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + @Test + fun `client message call procedure round trip`() { + val original = ClientMessage.CallProcedure( + requestId = 77u, + flags = 0u, + procedure = "get_leaderboard", + args = byteArrayOf(10, 20), + ) + val decoded = roundTripClientMessage(original) + assertEquals(original, decoded) + } + + // ---- ServerMessage round-trips (encode → decode → re-encode → assertContentEquals) ---- + // ServerMessage types containing BsatnRowList don't have value equality, + // so we verify encode→decode→re-encode produces identical bytes. + + @Test + fun `server message initial connection round trip`() { + val original = ServerMessage.InitialConnection( + identity = Identity(BigInteger.parseString("123456789ABCDEF", 16)), + connectionId = ConnectionId(BigInteger.parseString("FEDCBA987654321", 16)), + token = "my-auth-token", + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message subscribe applied round trip`() { + val original = ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = QuerySetId(5u), + rows = QueryRows(listOf( + SingleTableRows("player", buildRowList(SampleRow(1, "Alice").encode())), + )), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message subscribe applied empty rows round trip`() { + val original = ServerMessage.SubscribeApplied( + requestId = 0u, + querySetId = QuerySetId(0u), + rows = QueryRows(emptyList()), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message unsubscribe applied with rows round trip`() { + val original = ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = QuerySetId(3u), + rows = QueryRows(listOf( + SingleTableRows("item", buildRowList(SampleRow(42, "sword").encode())), + )), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message unsubscribe applied null 
rows round trip`() { + val original = ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = QuerySetId(3u), + rows = null, + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message subscription error with request id round trip`() { + val original = ServerMessage.SubscriptionError( + requestId = 10u, + querySetId = QuerySetId(4u), + error = "table not found", + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message subscription error null request id round trip`() { + val original = ServerMessage.SubscriptionError( + requestId = null, + querySetId = QuerySetId(4u), + error = "fatal error", + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message transaction update round trip`() { + val row1 = SampleRow(1, "Alice").encode() + val row2 = SampleRow(2, "Bob").encode() + val original = ServerMessage.TransactionUpdateMsg( + TransactionUpdate(listOf( + QuerySetUpdate( + QuerySetId(1u), + listOf( + TableUpdate("player", listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(row2), + deletes = buildRowList(row1), + ), + )), + ), + ), + )), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message transaction update event table round trip`() { + val row = SampleRow(1, "event_data").encode() + val original = ServerMessage.TransactionUpdateMsg( + TransactionUpdate(listOf( + QuerySetUpdate( + QuerySetId(2u), + listOf( + TableUpdate("events", listOf( + TableUpdateRows.EventTable(events = buildRowList(row)), + )), + ), + ), + )), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message one off query result ok round trip`() { + val original = ServerMessage.OneOffQueryResult( + requestId = 55u, + result = QueryResult.Ok(QueryRows(listOf( + SingleTableRows("users", buildRowList(SampleRow(1, "test").encode())), + ))), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message one off query result err round trip`() { + val 
original = ServerMessage.OneOffQueryResult( + requestId = 55u, + result = QueryResult.Err("syntax error near 'SELEC'"), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message reducer result ok round trip`() { + val original = ServerMessage.ReducerResultMsg( + requestId = 8u, + timestamp = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L), + result = ReducerOutcome.Ok( + retValue = byteArrayOf(42), + transactionUpdate = TransactionUpdate(emptyList()), + ), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message reducer result ok empty round trip`() { + val original = ServerMessage.ReducerResultMsg( + requestId = 9u, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message reducer result err round trip`() { + val original = ServerMessage.ReducerResultMsg( + requestId = 10u, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Err(byteArrayOf(0xDE.toByte(), 0xAD.toByte())), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message reducer result internal error round trip`() { + val original = ServerMessage.ReducerResultMsg( + requestId = 11u, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.InternalError("internal server error"), + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message procedure result returned round trip`() { + val original = ServerMessage.ProcedureResultMsg( + status = ProcedureStatus.Returned(byteArrayOf(1, 2, 3)), + timestamp = Timestamp.fromEpochMicroseconds(1_000_000L), + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + requestId = 20u, + ) + assertServerMessageRoundTrip(original) + } + + @Test + fun `server message procedure result internal error round trip`() { + val original = ServerMessage.ProcedureResultMsg( + status = ProcedureStatus.InternalError("proc failed"), + timestamp = Timestamp.UNIX_EPOCH, + 
totalHostExecutionDuration = TimeDuration(Duration.ZERO), + requestId = 21u, + ) + assertServerMessageRoundTrip(original) + } + + // ---- Helpers ---- + + /** Encode → decode round-trip for ClientMessage. Uses data class equals. */ + private fun roundTripClientMessage(original: ClientMessage): ClientMessage { + val bytes = ClientMessage.encodeToBytes(original) + return decodeClientMessage(BsatnReader(bytes)) + } + + /** + * Encode → decode → re-encode round-trip for ServerMessage. + * Asserts that the byte representation is identical after a round-trip. + */ + private fun assertServerMessageRoundTrip(original: ServerMessage) { + val bytes1 = encodeServerMessage(original) + val decoded = ServerMessage.decodeFromBytes(bytes1) + val bytes2 = encodeServerMessage(decoded) + assertContentEquals(bytes1, bytes2) + } + + // ---- Test-only decode for ClientMessage (inverse of ClientMessage.encode) ---- + + private fun decodeClientMessage(reader: BsatnReader): ClientMessage { + return when (val tag = reader.readSumTag().toInt()) { + 0 -> ClientMessage.Subscribe( + requestId = reader.readU32(), + querySetId = QuerySetId(reader.readU32()), + queryStrings = List(reader.readArrayLen()) { reader.readString() }, + ) + 1 -> ClientMessage.Unsubscribe( + requestId = reader.readU32(), + querySetId = QuerySetId(reader.readU32()), + flags = when (val ft = reader.readSumTag().toInt()) { + 0 -> UnsubscribeFlags.Default + 1 -> UnsubscribeFlags.SendDroppedRows + else -> error("Unknown UnsubscribeFlags tag: $ft") + }, + ) + 2 -> ClientMessage.OneOffQuery( + requestId = reader.readU32(), + queryString = reader.readString(), + ) + 3 -> ClientMessage.CallReducer( + requestId = reader.readU32(), + flags = reader.readU8(), + reducer = reader.readString(), + args = reader.readByteArray(), + ) + 4 -> ClientMessage.CallProcedure( + requestId = reader.readU32(), + flags = reader.readU8(), + procedure = reader.readString(), + args = reader.readByteArray(), + ) + else -> error("Unknown ClientMessage 
tag: $tag") + } + } + + // ---- Test-only encode for ServerMessage (inverse of ServerMessage.decode) ---- + + private fun encodeServerMessage(msg: ServerMessage): ByteArray { + val writer = BsatnWriter() + when (msg) { + is ServerMessage.InitialConnection -> { + writer.writeSumTag(0u) + msg.identity.encode(writer) + msg.connectionId.encode(writer) + writer.writeString(msg.token) + } + is ServerMessage.SubscribeApplied -> { + writer.writeSumTag(1u) + writer.writeU32(msg.requestId) + writer.writeU32(msg.querySetId.id) + encodeQueryRows(writer, msg.rows) + } + is ServerMessage.UnsubscribeApplied -> { + writer.writeSumTag(2u) + writer.writeU32(msg.requestId) + writer.writeU32(msg.querySetId.id) + if (msg.rows != null) { + writer.writeSumTag(0u) // Some + encodeQueryRows(writer, msg.rows) + } else { + writer.writeSumTag(1u) // None + } + } + is ServerMessage.SubscriptionError -> { + writer.writeSumTag(3u) + if (msg.requestId != null) { + writer.writeSumTag(0u) // Some + writer.writeU32(msg.requestId) + } else { + writer.writeSumTag(1u) // None + } + writer.writeU32(msg.querySetId.id) + writer.writeString(msg.error) + } + is ServerMessage.TransactionUpdateMsg -> { + writer.writeSumTag(4u) + encodeTransactionUpdate(writer, msg.update) + } + is ServerMessage.OneOffQueryResult -> { + writer.writeSumTag(5u) + writer.writeU32(msg.requestId) + when (val r = msg.result) { + is QueryResult.Ok -> { + writer.writeSumTag(0u) + encodeQueryRows(writer, r.rows) + } + is QueryResult.Err -> { + writer.writeSumTag(1u) + writer.writeString(r.error) + } + } + } + is ServerMessage.ReducerResultMsg -> { + writer.writeSumTag(6u) + writer.writeU32(msg.requestId) + msg.timestamp.encode(writer) + encodeReducerOutcome(writer, msg.result) + } + is ServerMessage.ProcedureResultMsg -> { + writer.writeSumTag(7u) + encodeProcedureStatus(writer, msg.status) + msg.timestamp.encode(writer) + msg.totalHostExecutionDuration.encode(writer) + writer.writeU32(msg.requestId) + } + } + return 
writer.toByteArray() + } + + private fun encodeQueryRows(writer: BsatnWriter, rows: QueryRows) { + writer.writeArrayLen(rows.tables.size) + for (t in rows.tables) { + writer.writeString(t.table) + encodeBsatnRowList(writer, t.rows) + } + } + + private fun encodeBsatnRowList(writer: BsatnWriter, rowList: BsatnRowList) { + encodeRowSizeHint(writer, rowList.sizeHint) + writer.writeU32(rowList.rowsSize.toUInt()) + val reader = rowList.rowsReader + if (rowList.rowsSize > 0) { + writer.writeRawBytes(reader.data.copyOfRange(reader.offset, reader.offset + rowList.rowsSize)) + } + } + + private fun encodeRowSizeHint(writer: BsatnWriter, hint: RowSizeHint) { + when (hint) { + is RowSizeHint.FixedSize -> { + writer.writeSumTag(0u) + writer.writeU16(hint.size) + } + is RowSizeHint.RowOffsets -> { + writer.writeSumTag(1u) + writer.writeArrayLen(hint.offsets.size) + for (o in hint.offsets) writer.writeU64(o) + } + } + } + + private fun encodeTransactionUpdate(writer: BsatnWriter, update: TransactionUpdate) { + writer.writeArrayLen(update.querySets.size) + for (qs in update.querySets) { + writer.writeU32(qs.querySetId.id) + writer.writeArrayLen(qs.tables.size) + for (tu in qs.tables) { + writer.writeString(tu.tableName) + writer.writeArrayLen(tu.rows.size) + for (tur in tu.rows) { + when (tur) { + is TableUpdateRows.PersistentTable -> { + writer.writeSumTag(0u) + encodeBsatnRowList(writer, tur.inserts) + encodeBsatnRowList(writer, tur.deletes) + } + is TableUpdateRows.EventTable -> { + writer.writeSumTag(1u) + encodeBsatnRowList(writer, tur.events) + } + } + } + } + } + } + + private fun encodeReducerOutcome(writer: BsatnWriter, outcome: ReducerOutcome) { + when (outcome) { + is ReducerOutcome.Ok -> { + writer.writeSumTag(0u) + writer.writeByteArray(outcome.retValue) + encodeTransactionUpdate(writer, outcome.transactionUpdate) + } + is ReducerOutcome.OkEmpty -> writer.writeSumTag(1u) + is ReducerOutcome.Err -> { + writer.writeSumTag(2u) + writer.writeByteArray(outcome.error) + } 
+ is ReducerOutcome.InternalError -> { + writer.writeSumTag(3u) + writer.writeString(outcome.message) + } + } + } + + private fun encodeProcedureStatus(writer: BsatnWriter, status: ProcedureStatus) { + when (status) { + is ProcedureStatus.Returned -> { + writer.writeSumTag(0u) + writer.writeByteArray(status.value) + } + is ProcedureStatus.InternalError -> { + writer.writeSumTag(1u) + writer.writeString(status.message) + } + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/QueryBuilderTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/QueryBuilderTest.kt new file mode 100644 index 00000000000..ccf8f6f0743 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/QueryBuilderTest.kt @@ -0,0 +1,322 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse + +@OptIn(InternalSpacetimeApi::class) +class QueryBuilderTest { + + // ---- SqlFormat ---- + + @Test + fun `quote ident simple`() { + assertEquals("\"players\"", SqlFormat.quoteIdent("players")) + } + + @Test + fun `quote ident escapes double quotes`() { + assertEquals("\"my\"\"table\"", SqlFormat.quoteIdent("my\"table")) + } + + @Test + fun `format string literal simple`() { + assertEquals("'hello'", SqlFormat.formatStringLiteral("hello")) + } + + @Test + fun `format string literal escapes single quotes`() { + assertEquals("'it''s'", SqlFormat.formatStringLiteral("it's")) + } + + @Test + fun `format hex literal strips 0x prefix`() { + assertEquals("0xABCD", SqlFormat.formatHexLiteral("0xABCD")) + } + + @Test + fun `format hex literal without prefix`() { + assertEquals("0xABCD", SqlFormat.formatHexLiteral("ABCD")) + } + + // ---- SqlLit NaN/Infinity rejection ---- + + @Test + 
fun `float nan throws`() { + assertFailsWith { SqlLit.float(Float.NaN) } + } + + @Test + fun `float positive infinity throws`() { + assertFailsWith { SqlLit.float(Float.POSITIVE_INFINITY) } + } + + @Test + fun `float negative infinity throws`() { + assertFailsWith { SqlLit.float(Float.NEGATIVE_INFINITY) } + } + + @Test + fun `double nan throws`() { + assertFailsWith { SqlLit.double(Double.NaN) } + } + + @Test + fun `double positive infinity throws`() { + assertFailsWith { SqlLit.double(Double.POSITIVE_INFINITY) } + } + + @Test + fun `double negative infinity throws`() { + assertFailsWith { SqlLit.double(Double.NEGATIVE_INFINITY) } + } + + @Test + fun `finite float succeeds`() { + assertEquals("3.14", SqlLit.float(3.14f).sql) + } + + @Test + fun `finite double succeeds`() { + assertEquals("2.718", SqlLit.double(2.718).sql) + } + + @Test + fun `float scientific notation produces plain decimal`() { + val sql = SqlLit.float(1.0E-7f).sql + assertFalse(sql.contains("E", ignoreCase = true), "Expected plain decimal, got: $sql") + } + + @Test + fun `double scientific notation produces plain decimal`() { + val sql = SqlLit.double(1.0E-7).sql + assertFalse(sql.contains("E", ignoreCase = true), "Expected plain decimal, got: $sql") + } + + // ---- BoolExpr ---- + + @Test + fun `bool expr and`() { + val a = BoolExpr("a = 1") + val b = BoolExpr("b = 2") + assertEquals("(a = 1 AND b = 2)", a.and(b).sql) + } + + @Test + fun `bool expr or`() { + val a = BoolExpr("a = 1") + val b = BoolExpr("b = 2") + assertEquals("(a = 1 OR b = 2)", a.or(b).sql) + } + + @Test + fun `bool expr not`() { + val a = BoolExpr("x > 5") + assertEquals("(NOT x > 5)", a.not().sql) + } + + // ---- Col comparisons ---- + + @Test + fun `col eq literal`() { + val col = Col("t", "x") + assertEquals("(\"t\".\"x\" = 42)", col.eq(SqlLiteral("42")).sql) + } + + @Test + fun `col eq other col`() { + val a = Col("t", "x") + val b = Col("t", "y") + assertEquals("(\"t\".\"x\" = \"t\".\"y\")", a.eq(b).sql) + } + + @Test + 
fun `col neq`() { + val col = Col("t", "name") + assertEquals("(\"t\".\"name\" <> 'alice')", col.neq(SqlLit.string("alice")).sql) + } + + @Test + fun `col lt lte gt gte`() { + val col = Col("t", "score") + assertEquals("(\"t\".\"score\" < 10)", col.lt(SqlLit.int(10)).sql) + assertEquals("(\"t\".\"score\" <= 10)", col.lte(SqlLit.int(10)).sql) + assertEquals("(\"t\".\"score\" > 10)", col.gt(SqlLit.int(10)).sql) + assertEquals("(\"t\".\"score\" >= 10)", col.gte(SqlLit.int(10)).sql) + } + + // ---- Col convenience extensions ---- + + @Test + fun `col eq raw int`() { + val col = Col("t", "x") + assertEquals("(\"t\".\"x\" = 42)", col.eq(42).sql) + } + + @Test + fun `col eq raw string`() { + val col = Col("t", "name") + assertEquals("(\"t\".\"name\" = 'bob')", col.eq("bob").sql) + } + + @Test + fun `col eq raw bool`() { + val col = Col("t", "active") + assertEquals("(\"t\".\"active\" = TRUE)", col.eq(true).sql) + } + + // ---- IxCol join equality ---- + + @Test + fun `ix col join eq`() { + val left = IxCol("l", "id") + val right = IxCol("r", "lid") + val join = left.eq(right) + assertEquals("\"l\".\"id\"", join.leftRefSql) + assertEquals("\"r\".\"lid\"", join.rightRefSql) + } + + // ---- Table.toSql ---- + + @Test + fun `table to sql`() { + val t = Table("players", Unit, Unit) + assertEquals("SELECT * FROM \"players\"", t.toSql()) + } + + // ---- Table.where -> FromWhere ---- + + data class FakeRow(val x: Int) + class FakeCols(tableName: String) { + val health = Col(tableName, "health") + val name = Col(tableName, "name") + val active = Col(tableName, "active") + } + + @Test + fun `table where bool col`() { + val t = Table("player", FakeCols("player"), Unit) + val q = t.where { c -> c.active } + assertEquals("SELECT * FROM \"player\" WHERE (\"player\".\"active\" = TRUE)", q.toSql()) + } + + @Test + fun `table where not bool col`() { + val t = Table("player", FakeCols("player"), Unit) + val q = t.where { c -> !c.active } + assertEquals("SELECT * FROM \"player\" WHERE (NOT 
(\"player\".\"active\" = TRUE))", q.toSql()) + } + + @Test + fun `table where to sql`() { + val t = Table("player", FakeCols("player"), Unit) + val q = t.where { c -> c.health.gt(50) } + assertEquals("SELECT * FROM \"player\" WHERE (\"player\".\"health\" > 50)", q.toSql()) + } + + @Test + fun `from where chained and`() { + val t = Table("player", FakeCols("player"), Unit) + val q = t.where { c -> c.health.gt(50) } + .where { c -> c.name.eq("alice") } + assertEquals( + "SELECT * FROM \"player\" WHERE ((\"player\".\"health\" > 50) AND (\"player\".\"name\" = 'alice'))", + q.toSql() + ) + } + + // ---- LeftSemiJoin ---- + + data class LeftRow(val id: Int) + data class RightRow(val lid: Int) + + class LeftIxCols(tableName: String) { + val id = IxCol(tableName, "id") + val verified = IxCol(tableName, "verified") + } + class RightIxCols(tableName: String) { + val lid = IxCol(tableName, "lid") + } + + @Test + fun `left semi join to sql`() { + val left = Table("a", Unit, LeftIxCols("a")) + val right = Table("b", Unit, RightIxCols("b")) + val q = left.leftSemijoin(right) { l, r -> l.id.eq(r.lid) } + assertEquals( + "SELECT \"a\".* FROM \"a\" JOIN \"b\" ON \"a\".\"id\" = \"b\".\"lid\"", + q.toSql() + ) + } + + // ---- RightSemiJoin ---- + + @Test + fun `right semi join to sql`() { + val left = Table("a", Unit, LeftIxCols("a")) + val right = Table("b", Unit, RightIxCols("b")) + val q = left.rightSemijoin(right) { l, r -> l.id.eq(r.lid) } + assertEquals( + "SELECT \"b\".* FROM \"a\" JOIN \"b\" ON \"a\".\"id\" = \"b\".\"lid\"", + q.toSql() + ) + } + + // ---- FromWhere -> LeftSemiJoin ---- + + class LeftCols(tableName: String) { + val status = Col(tableName, "status") + } + + @Test + fun `from where left semi join to sql`() { + val left = Table("a", LeftCols("a"), LeftIxCols("a")) + val right = Table("b", Unit, RightIxCols("b")) + val q = left.where { c: LeftCols -> c.status.eq("active") } + .leftSemijoin(right) { l, r -> l.id.eq(r.lid) } + assertEquals( + "SELECT \"a\".* FROM 
\"a\" JOIN \"b\" ON \"a\".\"id\" = \"b\".\"lid\" WHERE (\"a\".\"status\" = 'active')", + q.toSql() + ) + } + + // ---- where with IxCol ---- + + @Test + fun `table where ix col bool`() { + val t = Table("a", LeftCols("a"), LeftIxCols("a")) + val q = t.where { _, ix -> ix.verified } + assertEquals("SELECT * FROM \"a\" WHERE (\"a\".\"verified\" = TRUE)", q.toSql()) + } + + @Test + fun `table where not ix col bool`() { + val t = Table("a", LeftCols("a"), LeftIxCols("a")) + val q = t.where { _, ix -> !ix.verified } + assertEquals("SELECT * FROM \"a\" WHERE (NOT (\"a\".\"verified\" = TRUE))", q.toSql()) + } + + // ---- SqlLit factory methods ---- + + @Test + fun `sql lit bool`() { + assertEquals("TRUE", SqlLit.bool(true).sql) + assertEquals("FALSE", SqlLit.bool(false).sql) + } + + @Test + fun `sql lit numeric types`() { + assertEquals("42", SqlLit.int(42).sql) + assertEquals("100", SqlLit.long(100L).sql) + assertEquals("7", SqlLit.byte(7).sql) + assertEquals("1000", SqlLit.short(1000).sql) + assertEquals("3.14", SqlLit.float(3.14f).sql) + assertEquals("2.718", SqlLit.double(2.718).sql) + } + + @Test + fun `sql lit string`() { + assertEquals("'hello world'", SqlLit.string("hello world").sql) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/RawFakeTransport.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/RawFakeTransport.kt new file mode 100644 index 00000000000..66c3a1f6fff --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/RawFakeTransport.kt @@ -0,0 +1,52 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ClientMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.SpacetimeTransport
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.Transport
+import kotlinx.atomicfu.atomic
+import kotlinx.atomicfu.update
+import kotlinx.collections.immutable.persistentListOf
+import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
+
+/**
+ * A test transport that accepts raw byte arrays and decodes BSATN inside the
+ * [incoming] flow, mirroring [SpacetimeTransport]'s
+ * decode-in-flow behavior.
+ *
+ * This allows testing how [DbConnection] reacts to malformed frames:
+ * truncated BSATN, invalid sum tags, empty frames, etc.
+ * Decode errors surface as exceptions in the flow, which DbConnection's
+ * receive loop catches and routes to onDisconnect(error).
+ */
+internal class RawFakeTransport : Transport {
+    private val _rawIncoming = Channel<ByteArray>(Channel.UNLIMITED)
+    private val _sent = atomic(persistentListOf<ClientMessage>())
+    private var _connected = false
+
+    override suspend fun connect() {
+        _connected = true
+    }
+
+    override suspend fun send(message: ClientMessage) {
+        _sent.update { it.add(message) }
+    }
+
+    override fun incoming(): Flow<ServerMessage> = flow {
+        for (bytes in _rawIncoming) {
+            emit(ServerMessage.decodeFromBytes(bytes))
+        }
+    }
+
+    override suspend fun disconnect() {
+        _connected = false
+        _rawIncoming.close()
+    }
+
+    /** Send raw BSATN bytes to the client. Decode happens inside [incoming]. 
*/ + suspend fun sendRawToClient(bytes: ByteArray) { + _rawIncoming.send(bytes) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ReducerAndQueryEdgeCaseTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ReducerAndQueryEdgeCaseTest.kt new file mode 100644 index 00000000000..20b1a7139dd --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ReducerAndQueryEdgeCaseTest.kt @@ -0,0 +1,497 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class ReducerAndQueryEdgeCaseTest { + + // ========================================================================= + // One-Off Query Edge Cases + // ========================================================================= + + @Test + fun `multiple one off queries concurrently`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var result1: OneOffQueryResult? = null + var result2: OneOffQueryResult? = null + var result3: OneOffQueryResult? 
= null + val id1 = conn.oneOffQuery("SELECT 1") { result1 = it } + val id2 = conn.oneOffQuery("SELECT 2") { result2 = it } + val id3 = conn.oneOffQuery("SELECT 3") { result3 = it } + advanceUntilIdle() + + // Respond in reverse order + transport.sendToClient( + ServerMessage.OneOffQueryResult(requestId = id3, result = QueryResult.Ok(emptyQueryRows())) + ) + transport.sendToClient( + ServerMessage.OneOffQueryResult(requestId = id1, result = QueryResult.Ok(emptyQueryRows())) + ) + transport.sendToClient( + ServerMessage.OneOffQueryResult(requestId = id2, result = QueryResult.Err("error")) + ) + advanceUntilIdle() + + assertNotNull(result1) + assertNotNull(result2) + assertNotNull(result3) + assertTrue(result1 is SdkResult.Success) + assertTrue(result2 is SdkResult.Failure) + assertTrue(result3 is SdkResult.Success) + conn.disconnect() + } + + @Test + fun `one off query callback is removed after firing`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var callCount = 0 + val id = conn.oneOffQuery("SELECT 1") { callCount++ } + advanceUntilIdle() + + // Send result twice with same requestId + transport.sendToClient( + ServerMessage.OneOffQueryResult(requestId = id, result = QueryResult.Ok(emptyQueryRows())) + ) + advanceUntilIdle() + transport.sendToClient( + ServerMessage.OneOffQueryResult(requestId = id, result = QueryResult.Ok(emptyQueryRows())) + ) + advanceUntilIdle() + + assertEquals(1, callCount) // Should only fire once + conn.disconnect() + } + + // ========================================================================= + // Reducer Edge Cases + // ========================================================================= + + @Test + fun `reducer callback is removed after firing`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = 
CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var callCount = 0 + val id = conn.callReducer("add", byteArrayOf(), "args", callback = { callCount++ }) + advanceUntilIdle() + + // Send result twice + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertEquals(1, callCount) // Should only fire once + conn.disconnect() + } + + @Test + fun `reducer result ok with table updates mutates cache`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Subscribe first to establish the table + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + // Call reducer + var status: Status? 
= null + val id = conn.callReducer("add", byteArrayOf(), "args", callback = { ctx -> status = ctx.status }) + advanceUntilIdle() + + // Reducer result with table insert + val row = SampleRow(1, "FromReducer") + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Ok( + retValue = byteArrayOf(), + transactionUpdate = TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(row.encode()), + deletes = buildRowList(), + ) + ) + ) + ) + ) + ) + ), + ), + ) + ) + advanceUntilIdle() + + assertEquals(Status.Committed, status) + assertEquals(1, cache.count()) + assertEquals(row, cache.all().single()) + conn.disconnect() + } + + @Test + fun `reducer result with empty error bytes`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var status: Status? 
= null + val id = conn.callReducer("bad", byteArrayOf(), "args", callback = { ctx -> status = ctx.status }) + advanceUntilIdle() + + // Empty error bytes + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Err(byteArrayOf()), + ) + ) + advanceUntilIdle() + + assertTrue(status is Status.Failed) + assertTrue((status as Status.Failed).message.contains("undecodable")) + conn.disconnect() + } + + // ========================================================================= + // Multi-Table Transaction Processing + // ========================================================================= + + @Test + fun `transaction update across multiple tables`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + + val cacheA = createSampleCache() + val cacheB = createSampleCache() + conn.clientCache.register("table_a", cacheA) + conn.clientCache.register("table_b", cacheB) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM table_a", "SELECT * FROM table_b")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + // Transaction inserting into both tables + val rowA = SampleRow(1, "A") + val rowB = SampleRow(2, "B") + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "table_a", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(rowA.encode()), + deletes = buildRowList(), + ) + ) + ), + TableUpdate( + "table_b", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(rowB.encode()), + deletes = buildRowList(), + ) + ) + ), + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + 
assertEquals(1, cacheA.count()) + assertEquals(1, cacheB.count()) + assertEquals(rowA, cacheA.all().single()) + assertEquals(rowB, cacheB.all().single()) + conn.disconnect() + } + + @Test + fun `transaction update with unknown table is skipped`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("known", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM known")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + // Transaction with both known and unknown tables + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "unknown", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(SampleRow(1, "ghost").encode()), + deletes = buildRowList(), + ) + ) + ), + TableUpdate( + "known", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(SampleRow(2, "visible").encode()), + deletes = buildRowList(), + ) + ) + ), + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + // Known table gets the insert; unknown table is skipped without error + assertEquals(1, cache.count()) + assertEquals("visible", cache.all().single().name) + assertTrue(conn.isActive) + conn.disconnect() + } + + // ========================================================================= + // Concurrent Reducer Calls + // ========================================================================= + + @Test + fun `multiple concurrent reducer calls get correct callbacks`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + 
transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val results = mutableMapOf() + val id1 = conn.callReducer("add", byteArrayOf(1), "add_args", callback = { ctx -> + results["add"] = ctx.status + }) + val id2 = conn.callReducer("remove", byteArrayOf(2), "remove_args", callback = { ctx -> + results["remove"] = ctx.status + }) + val id3 = conn.callReducer("update", byteArrayOf(3), "update_args", callback = { ctx -> + results["update"] = ctx.status + }) + advanceUntilIdle() + + // Respond in reverse order + val writer = BsatnWriter() + writer.writeString("update failed") + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id3, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Err(writer.toByteArray()), + ) + ) + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id1, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id2, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertEquals(3, results.size) + assertEquals(Status.Committed, results["add"]) + assertEquals(Status.Committed, results["remove"]) + assertTrue(results["update"] is Status.Failed) + conn.disconnect() + } + + // ========================================================================= + // Content-Based Keying (Tables Without Primary Keys) + // ========================================================================= + + @Test + fun `content keyed cache insert and delete`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + + val row1 = SampleRow(1, "Alice") + val row2 = SampleRow(2, "Bob") + cache.applyInserts(STUB_CTX, buildRowList(row1.encode(), row2.encode())) + + assertEquals(2, cache.count()) + assertTrue(cache.all().containsAll(listOf(row1, row2))) + + // Delete row1 by content + val parsed = cache.parseDeletes(buildRowList(row1.encode())) + 
cache.applyDeletes(STUB_CTX, parsed) + + assertEquals(1, cache.count()) + assertEquals(row2, cache.all().single()) + } + + @Test + fun `content keyed cache duplicate insert increments ref count`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + + val row = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + assertEquals(1, cache.count()) // One unique row, ref count = 2 + + // First delete decrements ref count + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + assertEquals(1, cache.count()) // Still present + + // Second delete removes it + val parsed2 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed2) + assertEquals(0, cache.count()) + } + + @Test + fun `content keyed cache update by content`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + + val oldRow = SampleRow(1, "Alice") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + + // An update with same content in delete + different content in insert + // For content-keyed tables, the "update" detection is by key, + // and since keys are content-based, this is a delete+insert, not an update + val newRow = SampleRow(1, "Alice Updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(1, cache.count()) + assertEquals(newRow, cache.all().single()) + } + + // ========================================================================= + // Event Table Behavior + // ========================================================================= + + @Test + fun `event table does not store rows but fires callbacks`() { + val cache = createSampleCache() + val events = mutableListOf() + cache.onInsert { _, row -> events.add(row) } 
+ + val row1 = SampleRow(1, "Alice") + val row2 = SampleRow(2, "Bob") + val eventUpdate = TableUpdateRows.EventTable( + events = buildRowList(row1.encode(), row2.encode()) + ) + val parsed = cache.parseUpdate(eventUpdate) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + for (cb in callbacks) cb.invoke() + + assertEquals(0, cache.count()) // Not stored + assertEquals(listOf(row1, row2), events) // Callbacks fired + } + + @Test + fun `event table does not fire on before delete`() { + val cache = createSampleCache() + var beforeDeleteFired = false + cache.onBeforeDelete { _, _ -> beforeDeleteFired = true } + + val eventUpdate = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(1, "Alice").encode()) + ) + val parsed = cache.parseUpdate(eventUpdate) + cache.preApplyUpdate(STUB_CTX, parsed) + cache.applyUpdate(STUB_CTX, parsed) + + assertFalse(beforeDeleteFired) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ReducerIntegrationTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ReducerIntegrationTest.kt new file mode 100644 index 00000000000..c4fa317db4f --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ReducerIntegrationTest.kt @@ -0,0 +1,563 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) 
+class ReducerIntegrationTest { + + // --- Reducers --- + + @Test + fun `call reducer sends client message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.callReducer("add", byteArrayOf(1, 2, 3), "test-args") + advanceUntilIdle() + + val reducerMsg = transport.sentMessages.filterIsInstance().firstOrNull() + assertNotNull(reducerMsg) + assertEquals("add", reducerMsg.reducer) + assertTrue(reducerMsg.args.contentEquals(byteArrayOf(1, 2, 3))) + conn.disconnect() + } + + @Test + fun `reducer result ok fires callback with committed`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var status: Status? = null + val requestId = conn.callReducer( + reducerName = "add", + encodedArgs = byteArrayOf(), + typedArgs = "args", + callback = { ctx -> status = ctx.status }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Ok( + retValue = byteArrayOf(), + transactionUpdate = TransactionUpdate(emptyList()), + ), + ) + ) + advanceUntilIdle() + + assertEquals(Status.Committed, status) + conn.disconnect() + } + + @Test + fun `reducer result err fires callback with failed`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var status: Status? 
= null + val errorText = "something went wrong" + val writer = BsatnWriter() + writer.writeString(errorText) + val errorBytes = writer.toByteArray() + + val requestId = conn.callReducer( + reducerName = "bad_reducer", + encodedArgs = byteArrayOf(), + typedArgs = "args", + callback = { ctx -> status = ctx.status }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Err(errorBytes), + ) + ) + advanceUntilIdle() + + assertTrue(status is Status.Failed) + assertEquals(errorText, (status as Status.Failed).message) + conn.disconnect() + } + + // --- Reducer outcomes --- + + @Test + fun `reducer result ok empty fires callback with committed`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var status: Status? = null + val requestId = conn.callReducer( + reducerName = "noop", + encodedArgs = byteArrayOf(), + typedArgs = "args", + callback = { ctx -> status = ctx.status }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertEquals(Status.Committed, status) + conn.disconnect() + } + + @Test + fun `reducer result internal error fires callback with failed`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var status: Status? 
= null + val requestId = conn.callReducer( + reducerName = "broken", + encodedArgs = byteArrayOf(), + typedArgs = "args", + callback = { ctx -> status = ctx.status }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.InternalError("internal server error"), + ) + ) + advanceUntilIdle() + + assertTrue(status is Status.Failed) + assertEquals("internal server error", (status as Status.Failed).message) + conn.disconnect() + } + + // --- callReducer without callback (fire-and-forget) --- + + @Test + fun `call reducer without callback sends message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.callReducer("add", byteArrayOf(), "args", callback = null) + advanceUntilIdle() + + val sent = transport.sentMessages.filterIsInstance() + assertEquals(1, sent.size) + assertEquals("add", sent[0].reducer) + + // Sending a result for it should not crash (no callback registered) + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = sent[0].requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertTrue(conn.isActive) + conn.disconnect() + } + + // --- Reducer result before identity is set --- + + @Test + fun `reducer result before identity set is ignored`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + // Do NOT send InitialConnection — identity stays null + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = 1u, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + // Connection should still be active (message silently ignored) + assertTrue(conn.isActive) + conn.disconnect() + } + + @Test + fun `reducer result before identity cleans up call 
info and callbacks`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + // Do NOT send InitialConnection — identity stays null + + // Manually inject a pending reducer result as if the server responded + // before InitialConnection arrived. The requestId=1u won't match a real + // callReducer (which requires Connected + identity), but the cleanup + // path must still remove any stale entries and finish tracking. + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = 1u, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + // The stats tracker should have finished tracking (not leaked) + assertEquals(0, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + assertTrue(conn.isActive) + conn.disconnect() + } + + // --- decodeReducerError with corrupted BSATN --- + + @Test + fun `reducer err with corrupted bsatn does not crash`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var status: Status? 
= null + conn.callReducer("bad", byteArrayOf(), "args", callback = { ctx -> + status = ctx.status + }) + advanceUntilIdle() + + val sent = transport.sentMessages.filterIsInstance().last() + // Send Err with invalid BSATN bytes (not a valid BSATN string) + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = sent.requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Err(byteArrayOf(0xFF.toByte(), 0x00, 0x01)), + ) + ) + advanceUntilIdle() + + val capturedStatus = status + assertNotNull(capturedStatus) + assertTrue(capturedStatus is Status.Failed) + assertTrue(capturedStatus.message.contains("undecodable")) + conn.disconnect() + } + + // --- Reducer timeout and burst scenarios --- + + @Test + fun `pending reducer callbacks cleared on disconnect never fire`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var callbackFired = false + conn.callReducer("slow", byteArrayOf(), "args", callback = { _ -> + callbackFired = true + }) + advanceUntilIdle() + + // Verify the reducer is pending + assertEquals(1, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + + // Disconnect before the server responds — simulates a "timeout" scenario + conn.disconnect() + advanceUntilIdle() + + assertFalse(callbackFired, "Reducer callback must not fire after disconnect") + } + + @Test + fun `burst reducer calls all get unique request ids`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val count = 100 + val requestIds = mutableSetOf() + val results = mutableMapOf() + + // Fire 100 reducer calls in a burst + repeat(count) { i -> + val id = conn.callReducer("op", byteArrayOf(i.toByte()), "args-$i", callback = { ctx -> + results[i.toUInt()] = ctx.status + }) + requestIds.add(id) + } + advanceUntilIdle() + + // All IDs 
must be unique + assertEquals(count, requestIds.size) + assertEquals(count, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + + // Respond to all in order + for (id in requestIds) { + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + } + advanceUntilIdle() + + assertEquals(0, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + assertEquals(count, conn.stats.reducerRequestTracker.sampleCount) + conn.disconnect() + } + + @Test + fun `burst reducer calls responded out of order`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val count = 50 + val callbacks = mutableMapOf() + val requestIds = mutableListOf() + + repeat(count) { i -> + val id = conn.callReducer("op-$i", byteArrayOf(i.toByte()), "args-$i", callback = { ctx -> + callbacks[i.toUInt()] = ctx.status + }) + requestIds.add(id) + } + advanceUntilIdle() + + // Respond in reverse order + for (id in requestIds.reversed()) { + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = id, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + } + advanceUntilIdle() + + assertEquals(0, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + conn.disconnect() + } + + @Test + fun `reducer result after disconnect is dropped`() = runTest { + val transport = FakeTransport() + var callbackFired = false + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.callReducer("op", byteArrayOf(), "args", callback = { _ -> + callbackFired = true + }) + advanceUntilIdle() + + // Server closes the connection + transport.closeFromServer() + advanceUntilIdle() + assertFalse(conn.isActive) + + // Callback was cleared by failPendingOperations, never fires + assertFalse(callbackFired) + } 
+ + @Test + fun `reducer with table mutations and callback both fire`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + var reducerStatus: Status? = null + val insertedRows = mutableListOf() + cache.onInsert { _, row -> insertedRows.add(row) } + + val row1 = SampleRow(1, "Alice") + val row2 = SampleRow(2, "Bob") + + val requestId = conn.callReducer("add_two", byteArrayOf(), "args", callback = { ctx -> + reducerStatus = ctx.status + }) + advanceUntilIdle() + + // Reducer result inserts two rows in a single transaction + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.Ok( + retValue = byteArrayOf(), + transactionUpdate = TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(row1.encode(), row2.encode()), + deletes = buildRowList(), + ) + ) + ) + ) + ) + ) + ), + ), + ) + ) + advanceUntilIdle() + + // Both callbacks must have fired + assertEquals(Status.Committed, reducerStatus) + assertEquals(2, insertedRows.size) + assertEquals(2, cache.count()) + conn.disconnect() + } + + @Test + fun `many pending reducers all cleared on disconnect`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var firedCount = 0 + repeat(50) { + conn.callReducer("op", byteArrayOf(), "args", callback = { _ -> firedCount++ }) + } + 
advanceUntilIdle() + + assertEquals(50, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + + conn.disconnect() + advanceUntilIdle() + + assertEquals(0, firedCount, "No reducer callbacks should fire after disconnect") + } + + @Test + fun `mixed reducer outcomes in burst`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val results = mutableMapOf() + + val id1 = conn.callReducer("ok1", byteArrayOf(), "ok1", callback = { ctx -> + results["ok1"] = ctx.status + }) + val id2 = conn.callReducer("err", byteArrayOf(), "err", callback = { ctx -> + results["err"] = ctx.status + }) + val id3 = conn.callReducer("ok2", byteArrayOf(), "ok2", callback = { ctx -> + results["ok2"] = ctx.status + }) + val id4 = conn.callReducer("internal_err", byteArrayOf(), "internal_err", callback = { ctx -> + results["internal_err"] = ctx.status + }) + advanceUntilIdle() + + val errWriter = BsatnWriter() + errWriter.writeString("bad input") + + // Send all results at once — mixed outcomes + transport.sendToClient(ServerMessage.ReducerResultMsg(id1, Timestamp.UNIX_EPOCH, ReducerOutcome.OkEmpty)) + transport.sendToClient(ServerMessage.ReducerResultMsg(id2, Timestamp.UNIX_EPOCH, ReducerOutcome.Err(errWriter.toByteArray()))) + transport.sendToClient(ServerMessage.ReducerResultMsg(id3, Timestamp.UNIX_EPOCH, ReducerOutcome.OkEmpty)) + transport.sendToClient(ServerMessage.ReducerResultMsg(id4, Timestamp.UNIX_EPOCH, ReducerOutcome.InternalError("server crash"))) + advanceUntilIdle() + + assertEquals(4, results.size) + assertEquals(Status.Committed, results["ok1"]) + assertEquals(Status.Committed, results["ok2"]) + assertTrue(results["err"] is Status.Failed) + assertEquals("bad input", (results["err"] as Status.Failed).message) + assertTrue(results["internal_err"] is Status.Failed) + assertEquals("server crash", (results["internal_err"] as Status.Failed).message) + 
conn.disconnect() + } + + // --- typedArgs round-trip through ReducerCallInfo --- + + @Test + fun `call reducer typed args round trip through call info`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + data class MyArgs(val x: Int, val y: String) + val original = MyArgs(42, "hello") + var receivedArgs: MyArgs? = null + val requestId = conn.callReducer( + reducerName = "typed_op", + encodedArgs = byteArrayOf(), + typedArgs = original, + callback = { ctx -> receivedArgs = ctx.args }, + ) + advanceUntilIdle() + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + // The typed args must survive the round-trip through ReducerCallInfo(Any) + // back to EventContext.Reducer.args without corruption. + assertEquals(original, receivedArgs) + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ServerMessageTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ServerMessageTest.kt new file mode 100644 index 00000000000..af398d2c356 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ServerMessageTest.kt @@ -0,0 +1,328 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ProcedureStatus +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ReducerOutcome +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage 
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TransactionUpdate +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertIs +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class ServerMessageTest { + + /** Writes an Identity (U256 = 32 bytes LE) */ + private fun BsatnWriter.writeIdentity(value: BigInteger) = writeU256(value) + + /** Writes a ConnectionId (U128 = 16 bytes LE) */ + private fun BsatnWriter.writeConnectionId(value: BigInteger) = writeU128(value) + + /** Writes a Timestamp (I64 microseconds) */ + private fun BsatnWriter.writeTimestamp(micros: Long) = writeI64(micros) + + /** Writes a TimeDuration (I64 microseconds) */ + private fun BsatnWriter.writeTimeDuration(micros: Long) = writeI64(micros) + + /** Writes an empty QueryRows (array len = 0) */ + private fun BsatnWriter.writeEmptyQueryRows() = writeArrayLen(0) + + /** Writes an empty TransactionUpdate (array len = 0 querySets) */ + private fun BsatnWriter.writeEmptyTransactionUpdate() = writeArrayLen(0) + + // ---- InitialConnection (tag 0) ---- + + @Test + fun `initial connection decode`() { + val identityValue = BigInteger.parseString("12345678", 16) + val connIdValue = BigInteger.parseString("ABCD", 16) + + val writer = BsatnWriter() + writer.writeSumTag(0u) // tag = InitialConnection + writer.writeIdentity(identityValue) + writer.writeConnectionId(connIdValue) + writer.writeString("my-auth-token") + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(Identity(identityValue), msg.identity) + assertEquals(ConnectionId(connIdValue), msg.connectionId) + assertEquals("my-auth-token", msg.token) + } + + // ---- SubscribeApplied (tag 1) ---- + + @Test + fun `subscribe 
applied empty rows`() { + val writer = BsatnWriter() + writer.writeSumTag(1u) // tag = SubscribeApplied + writer.writeU32(42u) // requestId + writer.writeU32(7u) // querySetId + writer.writeEmptyQueryRows() + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(42u, msg.requestId) + assertEquals(7u, msg.querySetId.id) + assertTrue(msg.rows.tables.isEmpty()) + } + + // ---- UnsubscribeApplied (tag 2) ---- + + @Test + fun `unsubscribe applied with rows`() { + val writer = BsatnWriter() + writer.writeSumTag(2u) // tag = UnsubscribeApplied + writer.writeU32(10u) // requestId + writer.writeU32(3u) // querySetId + writer.writeSumTag(0u) // Option::Some + writer.writeEmptyQueryRows() + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(10u, msg.requestId) + assertEquals(3u, msg.querySetId.id) + assertNotNull(msg.rows) + } + + @Test + fun `unsubscribe applied without rows`() { + val writer = BsatnWriter() + writer.writeSumTag(2u) // tag = UnsubscribeApplied + writer.writeU32(10u) // requestId + writer.writeU32(3u) // querySetId + writer.writeSumTag(1u) // Option::None + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertNull(msg.rows) + } + + // ---- SubscriptionError (tag 3) ---- + + @Test + fun `subscription error with request id`() { + val writer = BsatnWriter() + writer.writeSumTag(3u) // tag = SubscriptionError + writer.writeSumTag(0u) // Option::Some(requestId) + writer.writeU32(55u) // requestId + writer.writeU32(8u) // querySetId + writer.writeString("table not found") + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(55u, msg.requestId) + assertEquals(8u, msg.querySetId.id) + assertEquals("table not found", msg.error) + } + + @Test + fun `subscription error without request id`() { + val writer = BsatnWriter() + writer.writeSumTag(3u) // tag = SubscriptionError + writer.writeSumTag(1u) // 
Option::None + writer.writeU32(8u) // querySetId + writer.writeString("internal error") + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertNull(msg.requestId) + assertEquals("internal error", msg.error) + } + + // ---- TransactionUpdateMsg (tag 4) ---- + + @Test + fun `transaction update empty query sets`() { + val writer = BsatnWriter() + writer.writeSumTag(4u) // tag = TransactionUpdateMsg + writer.writeEmptyTransactionUpdate() + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertTrue(msg.update.querySets.isEmpty()) + } + + // ---- OneOffQueryResult (tag 5) ---- + + @Test + fun `one off query result ok`() { + val writer = BsatnWriter() + writer.writeSumTag(5u) // tag = OneOffQueryResult + writer.writeU32(100u) // requestId + writer.writeSumTag(0u) // Result::Ok + writer.writeEmptyQueryRows() + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(100u, msg.requestId) + assertIs(msg.result) + } + + @Test + fun `one off query result err`() { + val writer = BsatnWriter() + writer.writeSumTag(5u) // tag = OneOffQueryResult + writer.writeU32(100u) // requestId + writer.writeSumTag(1u) // Result::Err + writer.writeString("syntax error in query") + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(100u, msg.requestId) + val err = assertIs(msg.result) + assertEquals("syntax error in query", err.error) + } + + // ---- ReducerResultMsg (tag 6) ---- + + @Test + fun `reducer result ok`() { + val writer = BsatnWriter() + writer.writeSumTag(6u) // tag = ReducerResultMsg + writer.writeU32(20u) // requestId + writer.writeTimestamp(1_000_000L) // timestamp + writer.writeSumTag(0u) // ReducerOutcome::Ok + writer.writeByteArray(byteArrayOf()) // retValue (empty) + writer.writeEmptyTransactionUpdate() + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(20u, msg.requestId) + val 
ok = assertIs(msg.result) + assertTrue(ok.retValue.isEmpty()) + assertTrue(ok.transactionUpdate.querySets.isEmpty()) + } + + @Test + fun `reducer result ok empty`() { + val writer = BsatnWriter() + writer.writeSumTag(6u) // tag = ReducerResultMsg + writer.writeU32(21u) // requestId + writer.writeTimestamp(2_000_000L) + writer.writeSumTag(1u) // ReducerOutcome::OkEmpty + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertIs(msg.result) + } + + @Test + fun `reducer result err`() { + val writer = BsatnWriter() + writer.writeSumTag(6u) // tag = ReducerResultMsg + writer.writeU32(22u) // requestId + writer.writeTimestamp(3_000_000L) + writer.writeSumTag(2u) // ReducerOutcome::Err + writer.writeByteArray(byteArrayOf(0xDE.toByte(), 0xAD.toByte())) + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + val err = assertIs(msg.result) + assertTrue(err.error.contentEquals(byteArrayOf(0xDE.toByte(), 0xAD.toByte()))) + } + + @Test + fun `reducer result internal error`() { + val writer = BsatnWriter() + writer.writeSumTag(6u) // tag = ReducerResultMsg + writer.writeU32(23u) // requestId + writer.writeTimestamp(4_000_000L) + writer.writeSumTag(3u) // ReducerOutcome::InternalError + writer.writeString("out of memory") + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + val err = assertIs(msg.result) + assertEquals("out of memory", err.message) + } + + // ---- ProcedureResultMsg (tag 7) ---- + + @Test + fun `procedure result returned`() { + val writer = BsatnWriter() + writer.writeSumTag(7u) // tag = ProcedureResultMsg + writer.writeSumTag(0u) // ProcedureStatus::Returned + writer.writeByteArray(byteArrayOf(42)) // return value + writer.writeTimestamp(5_000_000L) + writer.writeTimeDuration(100_000L) // 100ms + writer.writeU32(50u) // requestId + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(50u, msg.requestId) + val returned = 
assertIs(msg.status) + assertTrue(returned.value.contentEquals(byteArrayOf(42))) + } + + @Test + fun `procedure result internal error`() { + val writer = BsatnWriter() + writer.writeSumTag(7u) // tag = ProcedureResultMsg + writer.writeSumTag(1u) // ProcedureStatus::InternalError + writer.writeString("procedure failed") + writer.writeTimestamp(6_000_000L) + writer.writeTimeDuration(200_000L) + writer.writeU32(51u) // requestId + + val msg = ServerMessage.decodeFromBytes(writer.toByteArray()) + assertIs(msg) + assertEquals(51u, msg.requestId) + val err = assertIs(msg.status) + assertEquals("procedure failed", err.message) + } + + // ---- Unknown tag ---- + + @Test + fun `unknown tag throws`() { + val writer = BsatnWriter() + writer.writeSumTag(255u) // invalid tag + + assertFailsWith { + ServerMessage.decodeFromBytes(writer.toByteArray()) + } + } + + // ---- ReducerOutcome equality ---- + + @Test + fun `reducer outcome ok equality`() { + val a = ReducerOutcome.Ok(byteArrayOf(1, 2), TransactionUpdate(emptyList())) + val b = ReducerOutcome.Ok(byteArrayOf(1, 2), TransactionUpdate(emptyList())) + val c = ReducerOutcome.Ok(byteArrayOf(3, 4), TransactionUpdate(emptyList())) + + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertTrue(a != c) + } + + @Test + fun `reducer outcome err equality`() { + val a = ReducerOutcome.Err(byteArrayOf(1, 2)) + val b = ReducerOutcome.Err(byteArrayOf(1, 2)) + val c = ReducerOutcome.Err(byteArrayOf(3, 4)) + + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertTrue(a != c) + } + + @Test + fun `procedure status returned equality`() { + val a = ProcedureStatus.Returned(byteArrayOf(10)) + val b = ProcedureStatus.Returned(byteArrayOf(10)) + val c = ProcedureStatus.Returned(byteArrayOf(20)) + + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + assertTrue(a != c) + } +} diff --git 
a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/StatsIntegrationTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/StatsIntegrationTest.kt new file mode 100644 index 00000000000..d283661ab82 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/StatsIntegrationTest.kt @@ -0,0 +1,167 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.time.Duration + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class StatsIntegrationTest { + + // --- Stats tracking --- + + @Test + fun `stats subscription tracker increments on subscribe applied`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val tracker = conn.stats.subscriptionRequestTracker + assertEquals(0, tracker.sampleCount) + + val handle = conn.subscribe(listOf("SELECT * FROM player")) + // Request started but not yet finished + assertEquals(1, tracker.requestsAwaitingResponse) + + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertEquals(1, tracker.sampleCount) + assertEquals(0, tracker.requestsAwaitingResponse) + conn.disconnect() + } + + @Test + fun `stats reducer tracker increments on reducer result`() = runTest { + val transport = FakeTransport() + val conn = 
buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val tracker = conn.stats.reducerRequestTracker + assertEquals(0, tracker.sampleCount) + + val requestId = conn.callReducer("add", byteArrayOf(), "args", callback = null) + advanceUntilIdle() + assertEquals(1, tracker.requestsAwaitingResponse) + + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertEquals(1, tracker.sampleCount) + assertEquals(0, tracker.requestsAwaitingResponse) + conn.disconnect() + } + + @Test + fun `stats procedure tracker increments on procedure result`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val tracker = conn.stats.procedureRequestTracker + assertEquals(0, tracker.sampleCount) + + val requestId = conn.callProcedure("my_proc", byteArrayOf(), callback = null) + advanceUntilIdle() + assertEquals(1, tracker.requestsAwaitingResponse) + + transport.sendToClient( + ServerMessage.ProcedureResultMsg( + requestId = requestId, + timestamp = Timestamp.UNIX_EPOCH, + status = ProcedureStatus.Returned(byteArrayOf()), + totalHostExecutionDuration = TimeDuration(Duration.ZERO), + ) + ) + advanceUntilIdle() + + assertEquals(1, tracker.sampleCount) + assertEquals(0, tracker.requestsAwaitingResponse) + conn.disconnect() + } + + @Test + fun `stats one off tracker increments on query result`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val tracker = conn.stats.oneOffRequestTracker + assertEquals(0, tracker.sampleCount) + + val requestId = conn.oneOffQuery("SELECT 1") { _ -> } + advanceUntilIdle() + assertEquals(1, tracker.requestsAwaitingResponse) + + transport.sendToClient( + 
ServerMessage.OneOffQueryResult( + requestId = requestId, + result = QueryResult.Ok(emptyQueryRows()), + ) + ) + advanceUntilIdle() + + assertEquals(1, tracker.sampleCount) + assertEquals(0, tracker.requestsAwaitingResponse) + conn.disconnect() + } + + @Test + fun `stats apply message tracker increments on every server message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val tracker = conn.stats.applyMessageTracker + // InitialConnection is the first message processed + assertEquals(1, tracker.sampleCount) + + // Send a SubscribeApplied — second message + val handle = conn.subscribe(listOf("SELECT * FROM player")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertEquals(2, tracker.sampleCount) + + // Send a ReducerResult — third message + val reducerRequestId = conn.callReducer("add", byteArrayOf(), "args", callback = null) + advanceUntilIdle() + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = reducerRequestId, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + assertEquals(3, tracker.sampleCount) + + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/StatsTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/StatsTest.kt new file mode 100644 index 00000000000..c53aa32429d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/StatsTest.kt @@ -0,0 +1,324 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import 
kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds +import kotlin.time.TestTimeSource + +class StatsTest { + + // ---- Start / finish tracking ---- + + @Test + fun `start and finish returns true`() { + val tracker = NetworkRequestTracker() + val id = tracker.startTrackingRequest("test") + assertTrue(tracker.finishTrackingRequest(id)) + } + + @Test + fun `finish unknown id returns false`() { + val tracker = NetworkRequestTracker() + assertFalse(tracker.finishTrackingRequest(999u)) + } + + @Test + fun `sample count increments after finish`() { + val tracker = NetworkRequestTracker() + assertEquals(0, tracker.sampleCount) + + val id = tracker.startTrackingRequest() + tracker.finishTrackingRequest(id) + + assertEquals(1, tracker.sampleCount) + } + + @Test + fun `requests awaiting response tracks active requests`() { + val tracker = NetworkRequestTracker() + assertEquals(0, tracker.requestsAwaitingResponse) + + val id1 = tracker.startTrackingRequest() + val id2 = tracker.startTrackingRequest() + assertEquals(2, tracker.requestsAwaitingResponse) + + tracker.finishTrackingRequest(id1) + assertEquals(1, tracker.requestsAwaitingResponse) + + tracker.finishTrackingRequest(id2) + assertEquals(0, tracker.requestsAwaitingResponse) + } + + // ---- All-time min/max ---- + + @Test + fun `all time min max tracks extremes`() { + val tracker = NetworkRequestTracker() + assertNull(tracker.allTimeMinMax) + + tracker.insertSample(100.milliseconds, "fast") + tracker.insertSample(500.milliseconds, "slow") + tracker.insertSample(200.milliseconds, "medium") + + val result = assertNotNull(tracker.allTimeMinMax) + assertEquals(100.milliseconds, result.min.duration) + assertEquals("fast", result.min.metadata) + assertEquals(500.milliseconds, result.max.duration) + assertEquals("slow", result.max.metadata) + } + + @Test + fun 
`get all time min max returns null when empty`() { + val tracker = NetworkRequestTracker() + assertNull(tracker.allTimeMinMax) + } + + @Test + fun `get all time min max returns consistent pair`() { + val tracker = NetworkRequestTracker() + tracker.insertSample(100.milliseconds, "fast") + tracker.insertSample(500.milliseconds, "slow") + + val result = assertNotNull(tracker.allTimeMinMax) + assertEquals(100.milliseconds, result.min.duration) + assertEquals("fast", result.min.metadata) + assertEquals(500.milliseconds, result.max.duration) + assertEquals("slow", result.max.metadata) + } + + @Test + fun `get all time min max with single sample returns same for both`() { + val tracker = NetworkRequestTracker() + tracker.insertSample(250.milliseconds, "only") + + val result = assertNotNull(tracker.allTimeMinMax) + assertEquals(250.milliseconds, result.min.duration) + assertEquals(250.milliseconds, result.max.duration) + } + + // ---- Insert sample ---- + + @Test + fun `insert sample increments sample count`() { + val tracker = NetworkRequestTracker() + tracker.insertSample(50.milliseconds) + tracker.insertSample(100.milliseconds) + assertEquals(2, tracker.sampleCount) + } + + // ---- Metadata passthrough ---- + + @Test + fun `metadata passes through to sample`() { + val tracker = NetworkRequestTracker() + tracker.insertSample(10.milliseconds, "reducer:AddPlayer") + assertEquals("reducer:AddPlayer", tracker.allTimeMinMax?.min?.metadata) + } + + @Test + fun `finish tracking with override metadata`() { + val tracker = NetworkRequestTracker() + val id = tracker.startTrackingRequest("original") + tracker.finishTrackingRequest(id, "override") + assertEquals("override", tracker.allTimeMinMax?.min?.metadata) + } + + // ---- Windowed min/max ---- + + @Test + fun `get min max times returns null before window elapses`() { + val tracker = NetworkRequestTracker() + tracker.insertSample(100.milliseconds) + // The first window hasn't completed yet, so lastWindow is null + 
assertNull(tracker.minMaxTimes(10)) + } + + @Test + fun `multiple window sizes work independently`() { + val ts = TestTimeSource() + val tracker = NetworkRequestTracker(ts) + + // Register two window sizes + assertNull(tracker.minMaxTimes(1)) // 1-second window + assertNull(tracker.minMaxTimes(3)) // 3-second window + + // Window 1 (0s–1s): insert 100ms and 200ms + tracker.insertSample(100.milliseconds) + tracker.insertSample(200.milliseconds) + ts += 1.seconds + + // 1s window should have data; 3s window still pending + val w1 = assertNotNull(tracker.minMaxTimes(1)) + assertEquals(100.milliseconds, w1.min.duration) + assertEquals(200.milliseconds, w1.max.duration) + assertNull(tracker.minMaxTimes(3)) + + // Window 2 (1s–2s): insert 500ms only + tracker.insertSample(500.milliseconds) + ts += 1.seconds + + // 1s window rotated to new data; 3s window still pending + val w2 = assertNotNull(tracker.minMaxTimes(1)) + assertEquals(500.milliseconds, w2.min.duration) + assertNull(tracker.minMaxTimes(3)) + + // Advance to 3s — now the 3s window should have data from all samples + ts += 1.seconds + val w3 = assertNotNull(tracker.minMaxTimes(3)) + assertEquals(100.milliseconds, w3.min.duration) + assertEquals(500.milliseconds, w3.max.duration) + } + + @Test + fun `window rotation returns min max after window elapses`() { + val ts = TestTimeSource() + val tracker = NetworkRequestTracker(ts) + + // Register a 1-second window tracker + assertNull(tracker.minMaxTimes(1)) + + // Insert samples in the first window + tracker.insertSample(100.milliseconds, "fast") + tracker.insertSample(500.milliseconds, "slow") + tracker.insertSample(250.milliseconds, "mid") + + // Still within the first window — lastWindow has no data yet + assertNull(tracker.minMaxTimes(1)) + + // Advance past the 1-second window boundary + ts += 1.seconds + + // Now the previous window's data should be available + val result = assertNotNull(tracker.minMaxTimes(1)) + assertEquals(100.milliseconds, 
result.min.duration) + assertEquals("fast", result.min.metadata) + assertEquals(500.milliseconds, result.max.duration) + assertEquals("slow", result.max.metadata) + } + + @Test + fun `window rotation replaces with new window data`() { + val ts = TestTimeSource() + val tracker = NetworkRequestTracker(ts) + + // First window: samples 100ms and 500ms + tracker.minMaxTimes(1) // create tracker + tracker.insertSample(100.milliseconds, "w1-fast") + tracker.insertSample(500.milliseconds, "w1-slow") + + // Advance to second window + ts += 1.seconds + + // Insert new samples in the second window + tracker.insertSample(200.milliseconds, "w2-fast") + tracker.insertSample(300.milliseconds, "w2-slow") + + // getMinMax should return first window's data (100ms, 500ms) + val result1 = assertNotNull(tracker.minMaxTimes(1)) + assertEquals(100.milliseconds, result1.min.duration) + assertEquals(500.milliseconds, result1.max.duration) + + // Advance to third window — now second window becomes lastWindow + ts += 1.seconds + + val result2 = assertNotNull(tracker.minMaxTimes(1)) + assertEquals(200.milliseconds, result2.min.duration) + assertEquals("w2-fast", result2.min.metadata) + assertEquals(300.milliseconds, result2.max.duration) + assertEquals("w2-slow", result2.max.metadata) + } + + @Test + fun `window rotation returns null after two windows with no data`() { + val ts = TestTimeSource() + val tracker = NetworkRequestTracker(ts) + + // Insert samples in the first window + tracker.minMaxTimes(1) + tracker.insertSample(100.milliseconds, "data") + + // Advance past one window — data visible + ts += 1.seconds + assertNotNull(tracker.minMaxTimes(1)) + + // Advance past two full windows with no new data — + // the immediately preceding window is empty + ts += 2.seconds + assertNull(tracker.minMaxTimes(1)) + } + + @Test + fun `window rotation empty window preserves null min max`() { + val ts = TestTimeSource() + val tracker = NetworkRequestTracker(ts) + + // First window: insert data + 
tracker.minMaxTimes(1) + tracker.insertSample(100.milliseconds) + + // Advance to second window, insert nothing + ts += 1.seconds + + // First window data is available + assertNotNull(tracker.minMaxTimes(1)) + + // Advance to third window — second window had no data + ts += 1.seconds + + // lastWindow should be null since second window was empty + assertNull(tracker.minMaxTimes(1)) + } + + @Test + fun `window min max tracks extremes within window`() { + val ts = TestTimeSource() + val tracker = NetworkRequestTracker(ts) + tracker.minMaxTimes(1) + + // Insert samples that get progressively larger and smaller + tracker.insertSample(300.milliseconds, "mid") + tracker.insertSample(100.milliseconds, "smallest") + tracker.insertSample(900.milliseconds, "largest") + tracker.insertSample(200.milliseconds, "small") + + ts += 1.seconds + + val result = assertNotNull(tracker.minMaxTimes(1)) + assertEquals(100.milliseconds, result.min.duration) + assertEquals("smallest", result.min.metadata) + assertEquals(900.milliseconds, result.max.duration) + assertEquals("largest", result.max.metadata) + } + + @Test + fun `max trackers limit enforced`() { + val tracker = NetworkRequestTracker() + // Register 16 distinct window sizes (the max) + for (i in 1..16) { + tracker.minMaxTimes(i) + } + // 17th should throw + assertFailsWith { + tracker.minMaxTimes(17) + } + } + + // ---- Stats aggregator ---- + + @Test + fun `stats has all trackers`() { + val stats = Stats() + // Just verify the trackers are distinct instances + assertNotNull(stats.reducerRequestTracker) + assertNotNull(stats.procedureRequestTracker) + assertNotNull(stats.subscriptionRequestTracker) + assertNotNull(stats.oneOffRequestTracker) + assertNotNull(stats.applyMessageTracker) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionEdgeCaseTest.kt 
b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionEdgeCaseTest.kt new file mode 100644 index 00000000000..e0a3ee78564 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionEdgeCaseTest.kt @@ -0,0 +1,535 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class SubscriptionEdgeCaseTest { + + // ========================================================================= + // Subscription Lifecycle Edge Cases + // ========================================================================= + + @Test + fun `subscription state transitions pending to active to ended`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM t")) + assertEquals(SubscriptionState.PENDING, handle.state) + assertTrue(handle.isPending) + + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertEquals(SubscriptionState.ACTIVE, handle.state) + assertTrue(handle.isActive) + + handle.unsubscribe() + assertEquals(SubscriptionState.UNSUBSCRIBING, handle.state) + assertTrue(handle.isUnsubscribing) + + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + 
rows = null, + ) + ) + advanceUntilIdle() + assertEquals(SubscriptionState.ENDED, handle.state) + assertTrue(handle.isEnded) + + conn.disconnect() + } + + @Test + fun `unsubscribe from unsubscribing state throws`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM t")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + handle.unsubscribe() + assertTrue(handle.isUnsubscribing) + + // Second unsubscribe should fail — already unsubscribing + assertFailsWith { + handle.unsubscribe() + } + conn.disconnect() + } + + @Test + fun `subscription error from pending state ends subscription`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var errorReceived = false + val handle = conn.subscribe( + queries = listOf("SELECT * FROM bad"), + onError = listOf { _, _ -> errorReceived = true }, + ) + assertTrue(handle.isPending) + + transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = 1u, + querySetId = handle.querySetId, + error = "parse error", + ) + ) + advanceUntilIdle() + + assertTrue(handle.isEnded) + assertTrue(errorReceived) + // Should not be able to unsubscribe + assertFailsWith { handle.unsubscribe() } + conn.disconnect() + } + + @Test + fun `multiple subscriptions track independently`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle1 = 
conn.subscribe(listOf("SELECT * FROM t1")) + val handle2 = conn.subscribe(listOf("SELECT * FROM t2")) + + // Both start PENDING + assertTrue(handle1.isPending) + assertTrue(handle2.isPending) + + // Apply only handle1 + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertTrue(handle1.isActive) + assertTrue(handle2.isPending) // handle2 still pending + + // Apply handle2 + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertTrue(handle1.isActive) + assertTrue(handle2.isActive) + conn.disconnect() + } + + @Test + fun `disconnect marks all pending and active subscriptions as ended`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val pending = conn.subscribe(listOf("SELECT * FROM t1")) + val active = conn.subscribe(listOf("SELECT * FROM t2")) + + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = active.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertTrue(pending.isPending) + assertTrue(active.isActive) + + conn.disconnect() + advanceUntilIdle() + + assertTrue(pending.isEnded) + assertTrue(active.isEnded) + } + + @Test + fun `unsubscribe applied with rows removes from cache`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + 
transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Unsubscribe with rows returned + handle.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + + assertEquals(0, cache.count()) + conn.disconnect() + } + + // ========================================================================= + // Unsubscribe with Null Rows + // ========================================================================= + + @Test + fun `unsubscribe applied with null rows does not delete from cache`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Unsubscribe without SendDroppedRows — server sends null rows + handle.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + rows = null, + ) + ) + advanceUntilIdle() + + // Row stays in cache when rows is null + assertEquals(1, cache.count()) + assertTrue(handle.isEnded) + conn.disconnect() + } + + // 
========================================================================= + // Multiple Callbacks Registration + // ========================================================================= + + @Test + fun `multiple on applied callbacks all fire`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var count = 0 + val handle = conn.subscribe( + queries = listOf("SELECT * FROM t"), + onApplied = listOf( + { _ -> count++ }, + { _ -> count++ }, + { _ -> count++ }, + ), + ) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertEquals(3, count) + conn.disconnect() + } + + @Test + fun `multiple on error callbacks all fire`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var count = 0 + val handle = conn.subscribe( + queries = listOf("SELECT * FROM t"), + onError = listOf( + { _, _ -> count++ }, + { _, _ -> count++ }, + ), + ) + transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = 1u, + querySetId = handle.querySetId, + error = "oops", + ) + ) + advanceUntilIdle() + + assertEquals(2, count) + conn.disconnect() + } + + // ========================================================================= + // SubscribeApplied with Large Row Sets + // ========================================================================= + + @Test + fun `subscribe applied with many rows`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + 
transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // 100 rows + val rows = (1..100).map { SampleRow(it, "Row$it") } + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf( + SingleTableRows( + "sample", + buildRowList(*rows.map { it.encode() }.toTypedArray()) + ) + ) + ), + ) + ) + advanceUntilIdle() + + assertEquals(100, cache.count()) + conn.disconnect() + } + + // ========================================================================= + // SubscribeApplied for table not in cache + // ========================================================================= + + @Test + fun `subscribe applied for unregistered table ignores rows`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + // No cache registered for "sample" + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(SampleRow(1, "Alice").encode()))) + ), + ) + ) + advanceUntilIdle() + + // Should not crash — rows for unregistered tables are silently skipped + assertTrue(conn.isActive) + assertTrue(handle.isActive) + conn.disconnect() + } + + + // ========================================================================= + // doUnsubscribe callback-vs-CAS race + // ========================================================================= + + @Test + fun `unsubscribe on ended subscription does not leak callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + 
transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM t")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertEquals(SubscriptionState.ACTIVE, handle.state) + + // Server ends the subscription (e.g. SubscriptionError with null requestId triggers disconnect) + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + rows = null, + ) + ) + advanceUntilIdle() + assertEquals(SubscriptionState.ENDED, handle.state) + + // User tries to unsubscribe with a callback on the already-ended subscription. + // The callback must NOT fire — the CAS should fail and throw. + var callbackFired = false + assertFailsWith { + handle.unsubscribeThen { + callbackFired = true + } + } + advanceUntilIdle() + kotlin.test.assertFalse(callbackFired, "onEnd callback must not fire when CAS fails") + conn.disconnect() + } + + // ========================================================================= + // Concurrent subscribe + unsubscribe + // ========================================================================= + + @Test + fun `subscribe and immediate unsubscribe transitions correctly`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var appliedFired = false + var endFired = false + val handle = conn.subscribe( + queries = listOf("SELECT * FROM t"), + onApplied = listOf { _ -> appliedFired = true }, + ) + assertEquals(SubscriptionState.PENDING, handle.state) + + // Server confirms subscription + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + 
assertTrue(appliedFired) + assertEquals(SubscriptionState.ACTIVE, handle.state) + + // Immediately unsubscribe + handle.unsubscribeThen { _ -> endFired = true } + assertEquals(SubscriptionState.UNSUBSCRIBING, handle.state) + + // Server confirms unsubscribe + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + rows = null, + ) + ) + advanceUntilIdle() + assertTrue(endFired) + assertEquals(SubscriptionState.ENDED, handle.state) + conn.disconnect() + } + + @Test + fun `unsubscribe before applied throws`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM t")) + assertEquals(SubscriptionState.PENDING, handle.state) + + // Unsubscribe while still PENDING — CAS(ACTIVE→UNSUBSCRIBING) must fail + assertFailsWith { + handle.unsubscribe() + } + assertEquals(SubscriptionState.PENDING, handle.state) + conn.disconnect() + } + + @Test + fun `double unsubscribe throws`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport, exceptionHandler = CoroutineExceptionHandler { _, _ -> }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM t")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertEquals(SubscriptionState.ACTIVE, handle.state) + + handle.unsubscribe() + assertEquals(SubscriptionState.UNSUBSCRIBING, handle.state) + + // Second unsubscribe — state is UNSUBSCRIBING, not ACTIVE + assertFailsWith { + handle.unsubscribe() + } + conn.disconnect() + } +} diff --git 
a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionIntegrationTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionIntegrationTest.kt new file mode 100644 index 00000000000..9545a94857a --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/SubscriptionIntegrationTest.kt @@ -0,0 +1,1012 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class SubscriptionIntegrationTest { + + // --- Subscriptions --- + + @Test + fun `subscribe sends client message`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + conn.subscribe(listOf("SELECT * FROM player")) + advanceUntilIdle() + + val subMsg = transport.sentMessages.filterIsInstance().firstOrNull() + assertNotNull(subMsg) + assertEquals(listOf("SELECT * FROM player"), subMsg.queryStrings) + conn.disconnect() + } + + @Test + fun `subscribe applied fires on applied callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var applied = false + val handle = conn.subscribe( + queries = listOf("SELECT * FROM player"), + onApplied = listOf { _ -> applied = true }, + ) + + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = 
handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertTrue(applied) + assertTrue(handle.isActive) + conn.disconnect() + } + + @Test + fun `subscription error fires on error callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var errorMsg: String? = null + val handle = conn.subscribe( + queries = listOf("SELECT * FROM nonexistent"), + onError = listOf { _, err -> errorMsg = (err as SubscriptionError.ServerError).message }, + ) + + transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = 1u, + querySetId = handle.querySetId, + error = "table not found", + ) + ) + advanceUntilIdle() + + assertEquals("table not found", errorMsg) + assertTrue(handle.isEnded) + conn.disconnect() + } + + // --- Unsubscribe lifecycle --- + + @Test + fun `unsubscribe then callback fires on unsubscribe applied`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var applied = false + val handle = conn.subscribe( + queries = listOf("SELECT * FROM sample"), + onApplied = listOf { _ -> applied = true }, + ) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertTrue(applied) + assertTrue(handle.isActive) + + var unsubEndFired = false + handle.unsubscribeThen { _ -> unsubEndFired = true } + advanceUntilIdle() + assertTrue(handle.isUnsubscribing) + + // Verify Unsubscribe message was sent + val unsubMsg = transport.sentMessages.filterIsInstance().firstOrNull() + assertNotNull(unsubMsg) + assertEquals(handle.querySetId, unsubMsg.querySetId) + + // Server confirms unsubscribe + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + rows = 
null, + ) + ) + advanceUntilIdle() + + assertTrue(unsubEndFired) + assertTrue(handle.isEnded) + conn.disconnect() + } + + @Test + fun `unsubscribe then callback is set before message sent`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe( + queries = listOf("SELECT * FROM sample"), + onApplied = listOf { _ -> }, + ) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertTrue(handle.isActive) + + var callbackFired = false + handle.unsubscribeThen { _ -> callbackFired = true } + advanceUntilIdle() + + assertTrue(handle.isUnsubscribing) + + // Simulate immediate server response + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle.querySetId, + rows = null, + ) + ) + advanceUntilIdle() + + assertTrue(callbackFired, "Callback should fire even with immediate server response") + conn.disconnect() + } + + // --- Unsubscribe from wrong state --- + + @Test + fun `unsubscribe from pending state throws`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM player")) + // Handle is PENDING — no SubscribeApplied received yet + assertTrue(handle.isPending) + + assertFailsWith { + handle.unsubscribe() + } + conn.disconnect() + } + + @Test + fun `unsubscribe from ended state throws`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe( + queries = listOf("SELECT * FROM player"), + onError = listOf { _, _ -> }, + ) + + // Force ENDED via SubscriptionError + 
transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = 1u, + querySetId = handle.querySetId, + error = "error", + ) + ) + advanceUntilIdle() + assertTrue(handle.isEnded) + + assertFailsWith { + handle.unsubscribe() + } + conn.disconnect() + } + + // --- Unsubscribe with custom flags --- + + @Test + fun `unsubscribe with send dropped rows flag`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val handle = conn.subscribe(listOf("SELECT * FROM player")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertTrue(handle.isActive) + + handle.unsubscribe() + advanceUntilIdle() + + val unsub = transport.sentMessages.filterIsInstance().last() + assertEquals(handle.querySetId, unsub.querySetId) + assertEquals(UnsubscribeFlags.SendDroppedRows, unsub.flags) // hardcoded internally + conn.disconnect() + } + + // --- Subscription state machine edge cases --- + + @Test + fun `subscription error while unsubscribing moves to ended`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var errorMsg: String? 
= null + val handle = conn.subscribe( + queries = listOf("SELECT * FROM sample"), + onError = listOf { _, err -> errorMsg = (err as SubscriptionError.ServerError).message }, + ) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + assertTrue(handle.isActive) + + // Start unsubscribing + handle.unsubscribe() + advanceUntilIdle() + assertTrue(handle.isUnsubscribing) + + // Server sends error instead of UnsubscribeApplied + transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = 2u, + querySetId = handle.querySetId, + error = "internal error during unsubscribe", + ) + ) + advanceUntilIdle() + + assertTrue(handle.isEnded) + assertEquals("internal error during unsubscribe", errorMsg) + conn.disconnect() + } + + @Test + fun `transaction update during unsubscribe still applies`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Start unsubscribing + handle.unsubscribe() + advanceUntilIdle() + assertTrue(handle.isUnsubscribing) + + // A transaction arrives while unsubscribe is in-flight — row is inserted + val newRow = SampleRow(2, "Bob") + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + update = TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf(TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + 
deletes = buildRowList(), + )) + ) + ), + ) + ) + ) + ) + ) + advanceUntilIdle() + + // Transaction should still be applied to cache + assertEquals(2, cache.count()) + conn.disconnect() + } + + // --- Overlapping subscriptions --- + + @Test + fun `overlapping subscriptions ref count rows`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val encodedRow = row.encode() + + var insertCount = 0 + var deleteCount = 0 + cache.onInsert { _, _ -> insertCount++ } + cache.onDelete { _, _ -> deleteCount++ } + + // First subscription inserts row + val handle1 = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + assertEquals(1, insertCount) // onInsert fires for first occurrence + + // Second subscription also inserts the same row — ref count goes to 2 + val handle2 = conn.subscribe(listOf("SELECT * FROM sample WHERE id = 1")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) // Still one row (ref count = 2) + assertEquals(1, insertCount) // onInsert does NOT fire again + + // First subscription unsubscribes — ref count decrements to 1 + handle1.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 3u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + 
advanceUntilIdle() + assertEquals(1, cache.count()) // Row still present (ref count = 1) + assertEquals(0, deleteCount) // onDelete does NOT fire + + // Second subscription unsubscribes — ref count goes to 0 + handle2.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 4u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + assertEquals(0, cache.count()) // Row removed + assertEquals(1, deleteCount) // onDelete fires now + + conn.disconnect() + } + + @Test + fun `overlapping subscription transaction update affects both handles`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val encodedRow = row.encode() + + // Two subscriptions on the same table + val handle1 = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + + val handle2 = conn.subscribe(listOf("SELECT * FROM sample WHERE id = 1")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) // ref count = 2 + + // A TransactionUpdate that updates the row (delete old + insert new) + val updatedRow = SampleRow(1, "Alice Updated") + var updateOld: SampleRow? = null + var updateNew: SampleRow? 
= null + cache.onUpdate { _, old, new -> updateOld = old; updateNew = new } + + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle1.querySetId, + listOf( + TableUpdate( + "sample", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(updatedRow.encode()), + deletes = buildRowList(encodedRow), + ) + ) + ) + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + // The row should be updated in the cache + assertEquals(1, cache.count()) + assertEquals("Alice Updated", cache.all().first().name) + assertEquals(row, updateOld) + assertEquals(updatedRow, updateNew) + + // After unsubscribing handle1, the row still has ref count from handle2 + handle1.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 3u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(updatedRow.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) // Still present via handle2 + assertEquals("Alice Updated", cache.all().first().name) + + conn.disconnect() + } + + // --- Multi-subscription conflict scenarios --- + + @Test + fun `multiple subscriptions independent lifecycle`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var applied1 = false + var applied2 = false + val handle1 = conn.subscribe( + queries = listOf("SELECT * FROM players"), + onApplied = listOf { _ -> applied1 = true }, + ) + val handle2 = conn.subscribe( + queries = listOf("SELECT * FROM items"), + onApplied = listOf { _ -> applied2 = true }, + ) + advanceUntilIdle() + + // Only first subscription is confirmed + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertTrue(applied1) + 
assertFalse(applied2) + assertTrue(handle1.isActive) + assertTrue(handle2.isPending) + + // Unsubscribe first while second is still pending + handle1.unsubscribe() + advanceUntilIdle() + assertTrue(handle1.isUnsubscribing) + assertTrue(handle2.isPending) + + // Second subscription confirmed + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = emptyQueryRows(), + ) + ) + advanceUntilIdle() + + assertTrue(applied2) + assertTrue(handle2.isActive) + assertTrue(handle1.isUnsubscribing) + + // First unsubscribe confirmed + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 3u, + querySetId = handle1.querySetId, + rows = null, + ) + ) + advanceUntilIdle() + + assertTrue(handle1.isEnded) + assertTrue(handle2.isActive) + conn.disconnect() + } + + @Test + fun `subscribe applied during unsubscribe of overlapping subscription`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val sharedRow = SampleRow(1, "Alice") + val sub1OnlyRow = SampleRow(2, "Bob") + + // Sub1: gets both rows + val handle1 = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(sharedRow.encode(), sub1OnlyRow.encode()))) + ), + ) + ) + advanceUntilIdle() + assertEquals(2, cache.count()) + + // Start unsubscribing sub1 + handle1.unsubscribeThen {} + advanceUntilIdle() + assertTrue(handle1.isUnsubscribing) + + // Sub2 arrives while sub1 unsubscribe is in-flight — shares one row + val handle2 = conn.subscribe(listOf("SELECT * FROM sample WHERE id = 1")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = 
handle2.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(sharedRow.encode()))) + ), + ) + ) + advanceUntilIdle() + assertTrue(handle2.isActive) + // sharedRow now has ref count 2, sub1OnlyRow has ref count 1 + assertEquals(2, cache.count()) + + // Sub1 unsubscribe completes — drops both rows by ref count + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 3u, + querySetId = handle1.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(sharedRow.encode(), sub1OnlyRow.encode()))) + ), + ) + ) + advanceUntilIdle() + + // sharedRow survives (ref count 2 -> 1), sub1OnlyRow removed (ref count 1 -> 0) + assertEquals(1, cache.count()) + assertEquals(sharedRow, cache.all().single()) + assertTrue(handle1.isEnded) + assertTrue(handle2.isActive) + conn.disconnect() + } + + @Test + fun `subscription error does not affect other subscription cached rows`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + + // Sub1: active with a row in cache + val handle1 = conn.subscribe( + queries = listOf("SELECT * FROM sample"), + ) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(row.encode()))) + ), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + assertTrue(handle1.isActive) + + // Sub2: errors during subscribe (requestId present = non-fatal) + var sub2Error: SubscriptionError? 
= null + val handle2 = conn.subscribe( + queries = listOf("SELECT * FROM sample WHERE invalid"), + onError = listOf { _, err -> sub2Error = err }, + ) + transport.sendToClient( + ServerMessage.SubscriptionError( + requestId = 2u, + querySetId = handle2.querySetId, + error = "parse error", + ) + ) + advanceUntilIdle() + + // Sub2 is ended, but sub1's row must still be in cache + assertTrue(handle2.isEnded) + assertNotNull(sub2Error) + assertTrue(handle1.isActive) + assertEquals(1, cache.count()) + assertEquals(row, cache.all().single()) + assertTrue(conn.isActive) + conn.disconnect() + } + + @Test + fun `transaction update spans multiple query sets`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row1 = SampleRow(1, "Alice") + val row2 = SampleRow(2, "Bob") + + // Two subscriptions on the same table + val handle1 = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row1.encode())))), + ) + ) + advanceUntilIdle() + + val handle2 = conn.subscribe(listOf("SELECT * FROM sample WHERE id = 2")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList()))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Single TransactionUpdate with updates from BOTH query sets + var insertCount = 0 + cache.onInsert { _, _ -> insertCount++ } + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle1.querySetId, + listOf( + TableUpdate( + "sample", + listOf(TableUpdateRows.PersistentTable( + inserts = 
buildRowList(row2.encode()), + deletes = buildRowList(), + )) + ) + ), + ), + QuerySetUpdate( + handle2.querySetId, + listOf( + TableUpdate( + "sample", + listOf(TableUpdateRows.PersistentTable( + inserts = buildRowList(row2.encode()), + deletes = buildRowList(), + )) + ) + ), + ), + ) + ) + ) + ) + advanceUntilIdle() + + // row2 inserted via both query sets — ref count = 2, but onInsert fires once + assertEquals(2, cache.count()) + assertEquals(1, insertCount) + conn.disconnect() + } + + @Test + fun `resubscribe after unsubscribe completes`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + + // First subscription + val handle1 = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Unsubscribe + handle1.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 2u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(0, cache.count()) + assertTrue(handle1.isEnded) + + // Re-subscribe with the same query — fresh handle, row re-inserted + var reApplied = false + val handle2 = conn.subscribe( + queries = listOf("SELECT * FROM sample"), + onApplied = listOf { _ -> reApplied = true }, + ) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 3u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + + 
assertTrue(reApplied) + assertTrue(handle2.isActive) + assertEquals(1, cache.count()) + assertEquals(row, cache.all().single()) + // Old handle stays ended + assertTrue(handle1.isEnded) + conn.disconnect() + } + + @Test + fun `three overlapping subscriptions unsubscribe middle`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val encodedRow = row.encode() + + var deleteCount = 0 + cache.onDelete { _, _ -> deleteCount++ } + + // Three subscriptions all sharing the same row + val handle1 = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + + val handle2 = conn.subscribe(listOf("SELECT * FROM sample WHERE id = 1")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + + val handle3 = conn.subscribe(listOf("SELECT * FROM sample WHERE name = 'Alice'")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 3u, + querySetId = handle3.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + // ref count = 3 + assertEquals(1, cache.count()) + + // Unsubscribe middle subscription + handle2.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 4u, + querySetId = handle2.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + + // ref count 3 -> 2, row 
still present, no onDelete + assertEquals(1, cache.count()) + assertEquals(0, deleteCount) + assertTrue(handle2.isEnded) + assertTrue(handle1.isActive) + assertTrue(handle3.isActive) + + // Unsubscribe first + handle1.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 5u, + querySetId = handle1.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + + // ref count 2 -> 1, still present + assertEquals(1, cache.count()) + assertEquals(0, deleteCount) + + // Unsubscribe last — ref count -> 0, row deleted + handle3.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 6u, + querySetId = handle3.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(encodedRow)))), + ) + ) + advanceUntilIdle() + + assertEquals(0, cache.count()) + assertEquals(1, deleteCount) + conn.disconnect() + } + + @Test + fun `unsubscribe drops unique rows but keeps shared rows`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val sharedRow = SampleRow(1, "Alice") + val sub1Only = SampleRow(2, "Bob") + val sub2Only = SampleRow(3, "Charlie") + + // Sub1: gets sharedRow + sub1Only + val handle1 = conn.subscribe(listOf("SELECT * FROM sample WHERE id <= 2")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle1.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(sharedRow.encode(), sub1Only.encode()))) + ), + ) + ) + advanceUntilIdle() + assertEquals(2, cache.count()) + + // Sub2: gets sharedRow + sub2Only + val handle2 = conn.subscribe(listOf("SELECT * FROM sample WHERE id != 2")) + transport.sendToClient( + 
ServerMessage.SubscribeApplied( + requestId = 2u, + querySetId = handle2.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(sharedRow.encode(), sub2Only.encode()))) + ), + ) + ) + advanceUntilIdle() + assertEquals(3, cache.count()) + + val deleted = mutableListOf<Int>() + cache.onDelete { _, row -> deleted.add(row.id) } + + // Unsubscribe sub1 — drops sharedRow (ref 2->1) and sub1Only (ref 1->0) + handle1.unsubscribeThen {} + advanceUntilIdle() + transport.sendToClient( + ServerMessage.UnsubscribeApplied( + requestId = 3u, + querySetId = handle1.querySetId, + rows = QueryRows( + listOf(SingleTableRows("sample", buildRowList(sharedRow.encode(), sub1Only.encode()))) + ), + ) + ) + advanceUntilIdle() + + // sub1Only deleted, sharedRow survives + assertEquals(2, cache.count()) + assertEquals(listOf(2), deleted) // only sub1Only's id + val remaining = cache.all().sortedBy { it.id } + assertEquals(listOf(sharedRow, sub2Only), remaining) + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableCacheIntegrationTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableCacheIntegrationTest.kt new file mode 100644 index 00000000000..086ad7b9f83 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableCacheIntegrationTest.kt @@ -0,0 +1,464 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class)
+class TableCacheIntegrationTest { + + // --- Table cache --- + + @Test + fun `table cache updates on subscribe applied`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val row = SampleRow(1, "Alice") + val rowList = buildRowList(row.encode()) + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", rowList))), + ) + ) + advanceUntilIdle() + + assertEquals(1, cache.count()) + assertEquals("Alice", cache.all().first().name) + conn.disconnect() + } + + @Test + fun `table cache inserts and deletes via transaction update`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // First insert a row via SubscribeApplied + val row1 = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row1.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Now send a TransactionUpdate that inserts row2 and deletes row1 + val row2 = SampleRow(2, "Bob") + val inserts = buildRowList(row2.encode()) + val deletes = buildRowList(row1.encode()) + + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf(TableUpdateRows.PersistentTable(inserts, deletes)) + ) + ) + ) + ) + ) + ) + ) + 
advanceUntilIdle() + + assertEquals(1, cache.count()) + assertEquals("Bob", cache.all().first().name) + conn.disconnect() + } + + // --- Table callbacks through integration --- + + @Test + fun `table on insert fires on subscribe applied`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + var insertedRow: SampleRow? = null + cache.onInsert { _, row -> insertedRow = row } + + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + + assertEquals(row, insertedRow) + conn.disconnect() + } + + @Test + fun `table on delete fires on transaction update`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Insert a row first + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + var deletedRow: SampleRow? 
= null + cache.onDelete { _, r -> deletedRow = r } + + // Delete via TransactionUpdate + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row.encode()), + ) + ) + ) + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + assertEquals(row, deletedRow) + assertEquals(0, cache.count()) + conn.disconnect() + } + + @Test + fun `table on update fires on transaction update`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Insert a row first + val oldRow = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(oldRow.encode())))), + ) + ) + advanceUntilIdle() + + var updatedOld: SampleRow? = null + var updatedNew: SampleRow? 
= null + cache.onUpdate { _, old, new -> + updatedOld = old + updatedNew = new + } + + // Update: delete old row, insert new row with same PK + val newRow = SampleRow(1, "Alice Updated") + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + ) + ) + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + assertEquals(oldRow, updatedOld) + assertEquals(newRow, updatedNew) + assertEquals(1, cache.count()) + assertEquals("Alice Updated", cache.all().first().name) + conn.disconnect() + } + + // --- onBeforeDelete --- + + @Test + fun `on before delete fires before mutation`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Insert a row + val row = SampleRow(1, "Alice") + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(row.encode())))), + ) + ) + advanceUntilIdle() + assertEquals(1, cache.count()) + + // Track onBeforeDelete — at callback time, the row should still be in the cache + var cacheCountDuringCallback: Int? = null + var beforeDeleteRow: SampleRow? 
= null + cache.onBeforeDelete { _, r -> + beforeDeleteRow = r + cacheCountDuringCallback = cache.count() + } + + // Delete via TransactionUpdate + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate( + "sample", + listOf( + TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row.encode()), + ) + ) + ) + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + assertEquals(row, beforeDeleteRow) + assertEquals(1, cacheCountDuringCallback) // Row still present during onBeforeDelete + assertEquals(0, cache.count()) // Row removed after + conn.disconnect() + } + + // --- Cross-table preApply ordering --- + + @Test + fun `cross table pre apply runs before any apply`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + + // Set up two independent table caches + val cacheA = createSampleCache() + val cacheB = createSampleCache() + conn.clientCache.register("table_a", cacheA) + conn.clientCache.register("table_b", cacheB) + + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Subscribe and apply initial rows to both tables + val handle = conn.subscribe(listOf("SELECT * FROM table_a", "SELECT * FROM table_b")) + val rowA = SampleRow(1, "Alice") + val rowB = SampleRow(2, "Bob") + + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows( + listOf( + SingleTableRows("table_a", buildRowList(rowA.encode())), + SingleTableRows("table_b", buildRowList(rowB.encode())), + ) + ), + ) + ) + advanceUntilIdle() + assertEquals(1, cacheA.count()) + assertEquals(1, cacheB.count()) + + // Track event ordering: onBeforeDelete (preApply) vs onDelete (apply) + val events = mutableListOf<String>() + cacheA.onBeforeDelete { _, _ -> events.add("preApply_A") } + cacheA.onDelete { _, _ -> events.add("apply_A") } + cacheB.onBeforeDelete { _, _ -> 
events.add("preApply_B") } + cacheB.onDelete { _, _ -> events.add("apply_B") } + + // Send a TransactionUpdate that deletes from BOTH tables + transport.sendToClient( + ServerMessage.TransactionUpdateMsg( + TransactionUpdate( + listOf( + QuerySetUpdate( + handle.querySetId, + listOf( + TableUpdate("table_a", listOf(TableUpdateRows.PersistentTable(buildRowList(), buildRowList(rowA.encode())))), + TableUpdate("table_b", listOf(TableUpdateRows.PersistentTable(buildRowList(), buildRowList(rowB.encode())))), + ) + ) + ) + ) + ) + ) + advanceUntilIdle() + + // The key invariant: ALL preApply callbacks fire before ANY apply callbacks + assertEquals(listOf("preApply_A", "preApply_B", "apply_A", "apply_B"), events) + conn.disconnect() + } + + // --- Unknown querySetId / requestId (silent early returns) --- + + @Test + fun `subscribe applied for unknown query set id is ignored`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Register a callback to verify it does NOT fire + var insertFired = false + cache.onInsert { _, _ -> insertFired = true } + + // Send SubscribeApplied for a querySetId that was never subscribed + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 99u, + querySetId = QuerySetId(999u), + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(SampleRow(1, "ghost").encode())))), + ) + ) + advanceUntilIdle() + + // Should not crash, no rows inserted, no callbacks fired + assertTrue(conn.isActive) + assertEquals(0, cache.count()) + assertFalse(insertFired) + conn.disconnect() + } + + @Test + fun `reducer result for unknown request id is ignored`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + 
transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val cacheCountBefore = cache.count() + + // Send ReducerResultMsg with an Ok that has table updates — should be silently skipped + transport.sendToClient( + ServerMessage.ReducerResultMsg( + requestId = 999u, + timestamp = Timestamp.UNIX_EPOCH, + result = ReducerOutcome.OkEmpty, + ) + ) + advanceUntilIdle() + + assertTrue(conn.isActive) + assertEquals(cacheCountBefore, cache.count()) + conn.disconnect() + } + + @Test + fun `one off query result for unknown request id is ignored`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Register a real query so we can verify the unknown one doesn't interfere + var realCallbackFired = false + val realRequestId = conn.oneOffQuery("SELECT 1") { _ -> realCallbackFired = true } + advanceUntilIdle() + + // Send result for unknown requestId + transport.sendToClient( + ServerMessage.OneOffQueryResult( + requestId = 999u, + result = QueryResult.Ok(emptyQueryRows()), + ) + ) + advanceUntilIdle() + + // The unknown result should not fire the real callback + assertTrue(conn.isActive) + assertFalse(realCallbackFired) + + // Now send the real result — should fire + transport.sendToClient( + ServerMessage.OneOffQueryResult( + requestId = realRequestId, + result = QueryResult.Ok(emptyQueryRows()), + ) + ) + advanceUntilIdle() + assertTrue(realCallbackFired) + conn.disconnect() + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableCacheTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableCacheTest.kt new file mode 100644 index 00000000000..68d717f331d --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TableCacheTest.kt @@ -0,0 +1,1061 @@ 
+package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TableUpdateRows +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class TableCacheTest { + + @Test + fun `insert adds row`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + assertEquals(row, cache.all().single()) + } + + @Test + fun `insert multiple rows`() { + val cache = createSampleCache() + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode())) + assertEquals(2, cache.count()) + val all = cache.all().sortedBy { it.id } + assertEquals(listOf(r1, r2), all) + } + + @Test + fun `insert duplicate key increments ref count`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + } + + @Test + fun `delete removes row`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + assertEquals(0, cache.count()) + } + + @Test + fun `delete decrements ref count`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + assertEquals(1, cache.count()) + + val parsed2 = 
cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed2) + assertEquals(0, cache.count()) + } + + @Test + fun `update replaces row`() { + val cache = createSampleCache() + val oldRow = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + + val newRow = SampleRow(1, "alice_updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(1, cache.count()) + assertEquals(newRow, cache.all().single()) + } + + @Test + fun `update fires internal listeners`() { + val cache = createSampleCache() + val oldRow = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + + val inserts = mutableListOf<SampleRow>() + val deletes = mutableListOf<SampleRow>() + cache.addInternalInsertListener { inserts.add(it) } + cache.addInternalDeleteListener { deletes.add(it) } + + val newRow = SampleRow(1, "alice_updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(listOf(oldRow), deletes) + assertEquals(listOf(newRow), inserts) + } + + @Test + fun `event table does not store rows`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + val event = TableUpdateRows.EventTable( + events = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(0, cache.count()) + } + + @Test + fun `clear empties all rows`() { + val cache = createSampleCache() + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode())) + assertEquals(2, cache.count()) + + cache.clear() + assertEquals(0, cache.count()) + 
assertTrue(cache.all().isEmpty()) + } + + @Test + fun `clear fires internal delete listeners`() { + val cache = createSampleCache() + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode())) + + val deleted = mutableListOf<SampleRow>() + cache.addInternalDeleteListener { deleted.add(it) } + + cache.clear() + assertEquals(2, deleted.size) + assertTrue(deleted.containsAll(listOf(r1, r2))) + } + + @Test + fun `iter returns all rows`() { + val cache = createSampleCache() + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode())) + + val iterated = cache.iter().sortedBy { it.id }.toList() + assertEquals(listOf(r1, r2), iterated) + } + + @Test + fun `internal insert listener fires on insert`() { + val cache = createSampleCache() + val inserted = mutableListOf<SampleRow>() + cache.addInternalInsertListener { inserted.add(it) } + + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + assertEquals(listOf(row), inserted) + } + + @Test + fun `internal delete listener fires on delete`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val deleted = mutableListOf<SampleRow>() + cache.addInternalDeleteListener { deleted.add(it) } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + + assertEquals(listOf(row), deleted) + } + + @Test + fun `pure delete via update removes row`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(0, cache.count()) + } + + @Test + fun `pure insert via update adds 
row`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(row.encode()), + deletes = buildRowList(), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(1, cache.count()) + assertEquals(row, cache.all().single()) + } + + @Test + fun `content key table works`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, cache.count()) + } + + // ---- Content-based keying extended coverage ---- + + @Test + fun `content key insert multiple distinct rows`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + val r3 = SampleRow(1, "charlie") // same id, different name = different content key + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode(), r3.encode())) + assertEquals(3, cache.count()) + val all = cache.all().sortedBy { it.name } + assertEquals(listOf(r1, r2, r3), all) + } + + @Test + fun `content key duplicate insert increments ref count`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + // Same content = same key, refcount bumped but only 1 logical row + assertEquals(1, cache.count()) + + // First delete decrements refcount but row survives + val parsed1 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed1) + assertEquals(1, cache.count()) + + // Second delete removes the row + val parsed2 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed2) + assertEquals(0, cache.count()) + } + + @Test + fun `content key delete matches by bytes not field values`() { + val cache = 
TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + // Different content (same id but different name) should NOT delete the original + val differentContent = SampleRow(1, "bob") + val parsed = cache.parseDeletes(buildRowList(differentContent.encode())) + cache.applyDeletes(STUB_CTX, parsed) + assertEquals(1, cache.count(), "Delete with different content should not affect existing row") + + // Delete with exact same content works + val exactMatch = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, exactMatch) + assertEquals(0, cache.count()) + } + + @Test + fun `content key on insert callback fires`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val inserted = mutableListOf() + cache.onInsert { _, row -> inserted.add(row) } + + val row = SampleRow(1, "alice") + val callbacks = cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + callbacks.forEach { it.invoke() } + + assertEquals(listOf(row), inserted) + } + + @Test + fun `content key on insert does not fire for duplicate content`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val inserted = mutableListOf() + cache.onInsert { _, r -> inserted.add(r) } + + // Same content again — refcount bump only, no callback + val callbacks = cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + callbacks.forEach { it.invoke() } + assertTrue(inserted.isEmpty()) + } + + @Test + fun `content key on delete callback fires`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val deleted = mutableListOf() + cache.onDelete { _, r -> deleted.add(r) } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + 
callbacks.forEach { it.invoke() } + + assertEquals(listOf(row), deleted) + } + + @Test + fun `content key on delete does not fire when ref count still positive`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + // Insert twice — refcount = 2 + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val deleted = mutableListOf() + cache.onDelete { _, r -> deleted.add(r) } + + // First delete decrements refcount but doesn't remove + val parsed = cache.parseDeletes(buildRowList(row.encode())) + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + assertTrue(deleted.isEmpty(), "onDelete should not fire when refcount > 0") + assertEquals(1, cache.count()) + } + + @Test + fun `content key on before delete fires`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val beforeDeletes = mutableListOf() + cache.onBeforeDelete { _, r -> beforeDeletes.add(r) } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.preApplyDeletes(STUB_CTX, parsed) + + assertEquals(listOf(row), beforeDeletes) + } + + @Test + fun `content key on before delete skips when ref count high`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val beforeDeletes = mutableListOf() + cache.onBeforeDelete { _, r -> beforeDeletes.add(r) } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.preApplyDeletes(STUB_CTX, parsed) + + assertTrue(beforeDeletes.isEmpty(), "onBeforeDelete should not fire when refcount > 1") + } + + @Test + fun `content key two phase delete order`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = 
SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val events = mutableListOf() + cache.onBeforeDelete { _, _ -> events.add("before") } + cache.onDelete { _, _ -> events.add("delete") } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.preApplyDeletes(STUB_CTX, parsed) + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertEquals(listOf("before", "delete"), events) + } + + @Test + fun `content key update always decomposes into delete and insert`() { + // For content-key tables, old and new content have different bytes = different keys. + // So a PersistentTable update with delete(old) + insert(new) is never merged into onUpdate. + val cache = TableCache.withContentKey(::decodeSampleRow) + val oldRow = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + + val updates = mutableListOf>() + val inserts = mutableListOf() + val deletes = mutableListOf() + cache.onUpdate { _, old, new -> updates.add(old to new) } + cache.onInsert { _, row -> inserts.add(row) } + cache.onDelete { _, row -> deletes.add(row) } + + val newRow = SampleRow(1, "alice_updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + // onUpdate never fires — different content = different keys + assertTrue(updates.isEmpty(), "onUpdate should never fire for content-key tables with different content") + assertEquals(listOf(newRow), inserts) + assertEquals(listOf(oldRow), deletes) + assertEquals(1, cache.count()) + } + + @Test + fun `content key same content delete and insert merges into update`() { + // Edge case: if delete and insert have IDENTICAL content (same bytes), + // they share the same content key and ARE merged into an onUpdate. 
+ val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val updates = mutableListOf>() + val inserts = mutableListOf() + val deletes = mutableListOf() + cache.onUpdate { _, old, new -> updates.add(old to new) } + cache.onInsert { _, r -> inserts.add(r) } + cache.onDelete { _, r -> deletes.add(r) } + + // Delete and insert exact same content + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(row.encode()), + deletes = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(update) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + // Same content key in both sides → treated as update + assertEquals(1, updates.size) + assertEquals(row, updates[0].first) + assertEquals(row, updates[0].second) + assertTrue(inserts.isEmpty()) + assertTrue(deletes.isEmpty()) + assertEquals(1, cache.count()) + } + + @Test + fun `content key pre apply update fires before delete for pure deletes`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row1 = SampleRow(1, "alice") + val row2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(row1.encode(), row2.encode())) + + val beforeDeletes = mutableListOf() + cache.onBeforeDelete { _, r -> beforeDeletes.add(r) } + + // Pure delete of row1 (no matching insert) + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(), + deletes = buildRowList(row1.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.preApplyUpdate(STUB_CTX, parsed) + + assertEquals(listOf(row1), beforeDeletes) + } + + @Test + fun `content key pre apply update skips deletes that are updates`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val beforeDeletes = mutableListOf() + cache.onBeforeDelete { _, r -> beforeDeletes.add(r) } + + // 
Same content in both delete and insert = update, not pure delete + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(row.encode()), + deletes = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.preApplyUpdate(STUB_CTX, parsed) + + assertTrue(beforeDeletes.isEmpty(), "onBeforeDelete should not fire for updates") + } + + @Test + fun `content key internal listeners fire correctly`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val internalInserts = mutableListOf() + val internalDeletes = mutableListOf() + cache.addInternalInsertListener { internalInserts.add(it) } + cache.addInternalDeleteListener { internalDeletes.add(it) } + + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(listOf(row), internalInserts) + assertTrue(internalDeletes.isEmpty()) + + internalInserts.clear() + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + assertEquals(listOf(row), internalDeletes) + assertTrue(internalInserts.isEmpty()) + } + + @Test + fun `content key internal listeners do not fire for ref count bump`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val internalInserts = mutableListOf() + cache.addInternalInsertListener { internalInserts.add(it) } + + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, internalInserts.size) + + // Same content again — refcount bump, no internal listener + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + assertEquals(1, internalInserts.size, "Internal listener should not fire for refcount bump") + } + + @Test + fun `content key iter and all`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + val r3 = SampleRow(1, "charlie") // same id as r1 but different content key + cache.applyInserts(STUB_CTX, 
buildRowList(r1.encode(), r2.encode(), r3.encode())) + + val allRows = cache.all().sortedBy { it.name } + assertEquals(listOf(r1, r2, r3), allRows) + + val iterRows = cache.iter().sortedBy { it.name }.toList() + assertEquals(listOf(r1, r2, r3), iterRows) + } + + @Test + fun `content key clear removes all and fires internal listeners`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode())) + + val deleted = mutableListOf() + cache.addInternalDeleteListener { deleted.add(it) } + + cache.clear() + assertEquals(0, cache.count()) + assertTrue(cache.all().isEmpty()) + assertEquals(2, deleted.size) + assertTrue(deleted.containsAll(listOf(r1, r2))) + } + + @Test + fun `content key indexes work with content key cache`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val uniqueById = UniqueIndex(cache) { it.id } + val btreeByName = BTreeIndex(cache) { it.name } + + val r1 = SampleRow(1, "alice") + val r2 = SampleRow(2, "bob") + val r3 = SampleRow(3, "alice") // same name, different id + cache.applyInserts(STUB_CTX, buildRowList(r1.encode(), r2.encode(), r3.encode())) + + assertEquals(r1, uniqueById.find(1)) + assertEquals(r2, uniqueById.find(2)) + assertEquals(r3, uniqueById.find(3)) + assertEquals(2, btreeByName.filter("alice").size) + assertEquals(1, btreeByName.filter("bob").size) + + // Delete r1 — index updates + val parsed = cache.parseDeletes(buildRowList(r1.encode())) + cache.applyDeletes(STUB_CTX, parsed) + assertNull(uniqueById.find(1)) + assertEquals(1, btreeByName.filter("alice").size) + } + + @Test + fun `content key mixed update with pure delete and pure insert`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val existing1 = SampleRow(1, "alice") + val existing2 = SampleRow(2, "bob") + cache.applyInserts(STUB_CTX, buildRowList(existing1.encode(), existing2.encode())) + + val inserts = 
mutableListOf() + val deletes = mutableListOf() + val updates = mutableListOf>() + cache.onInsert { _, r -> inserts.add(r) } + cache.onDelete { _, r -> deletes.add(r) } + cache.onUpdate { _, old, new -> updates.add(old to new) } + + // Delete existing1, insert new row — these have different content keys + val newRow = SampleRow(3, "charlie") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(existing1.encode()), + ) + val parsed = cache.parseUpdate(update) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertEquals(listOf(newRow), inserts) + assertEquals(listOf(existing1), deletes) + assertTrue(updates.isEmpty()) + assertEquals(2, cache.count()) // existing2 + newRow + } + + @Test + fun `content key delete of non existent content is no op`() { + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val deleted = mutableListOf() + cache.onDelete { _, r -> deleted.add(r) } + + // Try to delete content that doesn't exist + val nonExistent = SampleRow(99, "nobody") + val parsed = cache.parseDeletes(buildRowList(nonExistent.encode())) + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertTrue(deleted.isEmpty()) + assertEquals(1, cache.count()) + } + + @Test + fun `content key ref count with callback lifecycle`() { + // Full lifecycle: insert x3 (same content), delete x3, verify callback timing + val cache = TableCache.withContentKey(::decodeSampleRow) + val row = SampleRow(1, "alice") + + val inserts = mutableListOf() + val deletes = mutableListOf() + cache.onInsert { _, _ -> inserts.add(cache.count()) } + cache.onDelete { _, _ -> deletes.add(cache.count()) } + + // First insert → callback fires (count=1 after insert) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())).forEach { it.invoke() } + 
assertEquals(listOf(1), inserts) + + // Second insert → no callback (refcount bump) + cache.applyInserts(STUB_CTX, buildRowList(row.encode())).forEach { it.invoke() } + assertEquals(listOf(1), inserts, "No callback on second insert") + + // Third insert → no callback + cache.applyInserts(STUB_CTX, buildRowList(row.encode())).forEach { it.invoke() } + assertEquals(listOf(1), inserts, "No callback on third insert") + + // First delete → no callback (refcount 3→2) + val p1 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, p1).forEach { it.invoke() } + assertTrue(deletes.isEmpty(), "No delete callback while refcount > 0") + + // Second delete → no callback (refcount 2→1) + val p2 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, p2).forEach { it.invoke() } + assertTrue(deletes.isEmpty(), "No delete callback while refcount > 0") + + // Third delete → callback fires (refcount 1→0, removed) + val p3 = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, p3).forEach { it.invoke() } + assertEquals(1, deletes.size, "Delete callback fires when row removed") + assertEquals(0, cache.count()) + } + + // ---- Public callback tests ---- + + @Test + fun `on insert callback fires`() { + val cache = createSampleCache() + val inserted = mutableListOf() + cache.onInsert { _, row -> inserted.add(row) } + + val row = SampleRow(1, "alice") + val callbacks = cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + callbacks.forEach { it.invoke() } + + assertEquals(listOf(row), inserted) + } + + @Test + fun `on insert callback does not fire for duplicate`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val inserted = mutableListOf() + cache.onInsert { _, r -> inserted.add(r) } + + // Insert same key again — should NOT fire onInsert (ref count bump only) + val callbacks = cache.applyInserts(STUB_CTX, 
buildRowList(row.encode())) + callbacks.forEach { it.invoke() } + + assertTrue(inserted.isEmpty()) + } + + @Test + fun `on delete callback fires`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val deleted = mutableListOf() + cache.onDelete { _, r -> deleted.add(r) } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertEquals(listOf(row), deleted) + } + + @Test + fun `on update callback fires`() { + val cache = createSampleCache() + val oldRow = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(oldRow.encode())) + + val updates = mutableListOf>() + cache.onUpdate { _, old, new -> updates.add(old to new) } + + val newRow = SampleRow(1, "alice_updated") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertEquals(1, updates.size) + assertEquals(oldRow, updates[0].first) + assertEquals(newRow, updates[0].second) + } + + @Test + fun `on before delete fires`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val beforeDeletes = mutableListOf() + cache.onBeforeDelete { _, r -> beforeDeletes.add(r) } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.preApplyDeletes(STUB_CTX, parsed) + + assertEquals(listOf(row), beforeDeletes) + } + + @Test + fun `pre apply then apply deletes order correct`() { + val cache = createSampleCache() + val row = SampleRow(1, "alice") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + + val events = mutableListOf() + cache.onBeforeDelete { _, _ -> events.add("before") } + cache.onDelete { _, _ -> 
events.add("delete") } + + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.preApplyDeletes(STUB_CTX, parsed) // before fires here + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } // delete fires here + + assertEquals(listOf("before", "delete"), events) + } + + @Test + fun `remove on insert stops callback`() { + val cache = createSampleCache() + val inserted = mutableListOf() + val cb: (EventContext, SampleRow) -> Unit = { _, row -> inserted.add(row) } + cache.onInsert(cb) + + val r1 = SampleRow(1, "alice") + val callbacks1 = cache.applyInserts(STUB_CTX, buildRowList(r1.encode())) + callbacks1.forEach { it.invoke() } + assertEquals(1, inserted.size) + + cache.removeOnInsert(cb) + + val r2 = SampleRow(2, "bob") + val callbacks2 = cache.applyInserts(STUB_CTX, buildRowList(r2.encode())) + callbacks2.forEach { it.invoke() } + assertEquals(1, inserted.size) // no new insert + } + + @Test + fun `event table fires insert callbacks`() { + val cache = createSampleCache() + val inserted = mutableListOf() + cache.onInsert { _, row -> inserted.add(row) } + + val row = SampleRow(1, "event_row") + val event = TableUpdateRows.EventTable( + events = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(event) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + // Event table rows fire callbacks but don't persist + assertEquals(1, inserted.size) + assertEquals(0, cache.count()) + } + + // ---- Event table extended coverage ---- + + @Test + fun `event table batch multiple rows`() { + val cache = createSampleCache() + val inserted = mutableListOf() + cache.onInsert { _, row -> inserted.add(row) } + + val rows = (1..10).map { SampleRow(it, "evt-$it") } + val event = TableUpdateRows.EventTable( + events = buildRowList(*rows.map { it.encode() }.toTypedArray()), + ) + val parsed = cache.parseUpdate(event) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + 
callbacks.forEach { it.invoke() } + + assertEquals(10, inserted.size) + assertEquals(rows, inserted) + assertEquals(0, cache.count()) + } + + @Test + fun `event table on delete callback never fires`() { + val cache = createSampleCache() + var deleteFired = false + cache.onDelete { _, _ -> deleteFired = true } + + val event = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(1, "evt").encode()), + ) + val parsed = cache.parseUpdate(event) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertFalse(deleteFired, "onDelete should never fire for event tables") + } + + @Test + fun `event table on update callback never fires`() { + val cache = createSampleCache() + var updateFired = false + cache.onUpdate { _, _, _ -> updateFired = true } + + val event = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(1, "evt").encode()), + ) + val parsed = cache.parseUpdate(event) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertFalse(updateFired, "onUpdate should never fire for event tables") + } + + @Test + fun `event table on before delete never fires`() { + val cache = createSampleCache() + var beforeDeleteFired = false + cache.onBeforeDelete { _, _ -> beforeDeleteFired = true } + + val event = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(1, "evt").encode()), + ) + val parsed = cache.parseUpdate(event) + cache.preApplyUpdate(STUB_CTX, parsed) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertFalse(beforeDeleteFired, "onBeforeDelete should never fire for event tables") + } + + @Test + fun `event table remove on insert stops callback`() { + val cache = createSampleCache() + val inserted = mutableListOf() + val cb: (EventContext, SampleRow) -> Unit = { _, row -> inserted.add(row) } + cache.onInsert(cb) + + // First event fires callback + val event1 = TableUpdateRows.EventTable( + events = 
buildRowList(SampleRow(1, "first").encode()), + ) + val parsed1 = cache.parseUpdate(event1) + cache.applyUpdate(STUB_CTX, parsed1).forEach { it.invoke() } + assertEquals(1, inserted.size) + + // Remove callback, second event should NOT fire it + cache.removeOnInsert(cb) + val event2 = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(2, "second").encode()), + ) + val parsed2 = cache.parseUpdate(event2) + cache.applyUpdate(STUB_CTX, parsed2).forEach { it.invoke() } + assertEquals(1, inserted.size, "Callback should not fire after removeOnInsert") + } + + @Test + fun `event table sequential updates never accumulate`() { + val cache = createSampleCache() + val allInserted = mutableListOf() + cache.onInsert { _, row -> allInserted.add(row) } + + // Send 5 sequential event updates + for (batch in 0 until 5) { + val rows = (1..3).map { SampleRow(batch * 3 + it, "b$batch-$it") } + val event = TableUpdateRows.EventTable( + events = buildRowList(*rows.map { it.encode() }.toTypedArray()), + ) + val parsed = cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed).forEach { it.invoke() } + + // Cache must remain empty after every batch + assertEquals(0, cache.count(), "Cache should stay empty after event batch $batch") + } + + // All 15 callbacks should have fired + assertEquals(15, allInserted.size) + } + + @Test + fun `event table does not affect internal listeners`() { + val cache = createSampleCache() + val internalInserts = mutableListOf() + val internalDeletes = mutableListOf() + cache.addInternalInsertListener { internalInserts.add(it) } + cache.addInternalDeleteListener { internalDeletes.add(it) } + + val event = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(1, "evt").encode()), + ) + val parsed = cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed) + + // Internal listeners should NOT fire for event tables + assertTrue(internalInserts.isEmpty(), "Internal insert listener should not fire for event tables") + 
assertTrue(internalDeletes.isEmpty(), "Internal delete listener should not fire for event tables") + } + + @Test + fun `event table indexes stay empty`() { + val cache = createSampleCache() + val uniqueIndex = UniqueIndex(cache) { it.id } + val btreeIndex = BTreeIndex(cache) { it.name } + + val event = TableUpdateRows.EventTable( + events = buildRowList( + SampleRow(1, "evt-a").encode(), + SampleRow(2, "evt-b").encode(), + SampleRow(3, "evt-a").encode(), + ), + ) + val parsed = cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed) + + // Indexes should remain empty since internal listeners don't fire + assertEquals(null, uniqueIndex.find(1)) + assertEquals(null, uniqueIndex.find(2)) + assertTrue(btreeIndex.filter("evt-a").isEmpty()) + assertTrue(btreeIndex.filter("evt-b").isEmpty()) + } + + @Test + fun `event table duplicate rows both fire callbacks`() { + val cache = createSampleCache() + val inserted = mutableListOf() + cache.onInsert { _, row -> inserted.add(row) } + + // Same row data sent twice — both should fire callbacks (no deduplication) + val row = SampleRow(1, "dup") + val event = TableUpdateRows.EventTable( + events = buildRowList(row.encode(), row.encode()), + ) + val parsed = cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed).forEach { it.invoke() } + + assertEquals(2, inserted.size, "Both duplicate event rows should fire callbacks") + assertEquals(row, inserted[0]) + assertEquals(row, inserted[1]) + assertEquals(0, cache.count()) + } + + @Test + fun `event table after persistent insert does not affect cached rows`() { + val cache = createSampleCache() + + // Persistent insert + val persistentRow = SampleRow(1, "persistent") + cache.applyInserts(STUB_CTX, buildRowList(persistentRow.encode())) + assertEquals(1, cache.count()) + + // Event with same primary key — should NOT affect the cached row + val event = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(1, "event-version").encode()), + ) + val parsed = 
cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed) + + assertEquals(1, cache.count()) + assertEquals(persistentRow, cache.all().single(), "Persistent row should be untouched by event table update") + } + + @Test + fun `event table empty events produces no callbacks`() { + val cache = createSampleCache() + var callbackCount = 0 + cache.onInsert { _, _ -> callbackCount++ } + + val event = TableUpdateRows.EventTable( + events = buildRowList(), // empty + ) + val parsed = cache.parseUpdate(event) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + + assertEquals(0, callbackCount, "Empty event table should produce no callbacks") + assertEquals(0, cache.count()) + } + + @Test + fun `event table multiple callbacks all fire`() { + val cache = createSampleCache() + val cb1 = mutableListOf() + val cb2 = mutableListOf() + val cb3 = mutableListOf() + cache.onInsert { _, row -> cb1.add(row) } + cache.onInsert { _, row -> cb2.add(row) } + cache.onInsert { _, row -> cb3.add(row) } + + val row = SampleRow(1, "evt") + val event = TableUpdateRows.EventTable( + events = buildRowList(row.encode()), + ) + val parsed = cache.parseUpdate(event) + cache.applyUpdate(STUB_CTX, parsed).forEach { it.invoke() } + + assertEquals(listOf(row), cb1) + assertEquals(listOf(row), cb2) + assertEquals(listOf(row), cb3) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TestHelpers.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TestHelpers.kt new file mode 100644 index 00000000000..63a9dbbca7c --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TestHelpers.kt @@ -0,0 +1,41 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.BsatnRowList +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.RowSizeHint + +data class SampleRow(val id: Int, val name: String) + +fun SampleRow.encode(): ByteArray { + val writer = BsatnWriter() + writer.writeI32(id) + writer.writeString(name) + return writer.toByteArray() +} + +fun decodeSampleRow(reader: BsatnReader): SampleRow { + val id = reader.readI32() + val name = reader.readString() + return SampleRow(id, name) +} + +fun buildRowList(vararg rows: ByteArray): BsatnRowList { + val writer = BsatnWriter() + val offsets = mutableListOf() + var offset = 0uL + for (row in rows) { + offsets.add(offset) + writer.writeRawBytes(row) + offset += row.size.toULong() + } + return BsatnRowList( + sizeHint = RowSizeHint.RowOffsets(offsets), + rowsData = writer.toByteArray(), + ) +} + +val STUB_CTX: EventContext = StubEventContext() + +fun createSampleCache(): TableCache = + TableCache.withPrimaryKey(::decodeSampleRow) { it.id } diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TransportAndFrameTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TransportAndFrameTest.kt new file mode 100644 index 00000000000..54fa6290c0a --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TransportAndFrameTest.kt @@ -0,0 +1,479 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.* +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.transport.SpacetimeTransport +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId 
+import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import io.ktor.client.HttpClient +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertContains +import kotlin.test.assertTrue + +@OptIn(kotlinx.coroutines.ExperimentalCoroutinesApi::class) +class TransportAndFrameTest { + + // --- Mid-stream transport failures --- + + @Test + fun `transport error fires on disconnect with error`() = runTest { + val transport = FakeTransport() + var disconnectError: Throwable? = null + var disconnected = false + val conn = buildTestConnection(transport, onDisconnect = { _, err -> + disconnected = true + disconnectError = err + }) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + assertTrue(conn.isActive) + + // Simulate mid-stream transport error + val networkError = RuntimeException("connection reset by peer") + transport.closeWithError(networkError) + advanceUntilIdle() + + assertTrue(disconnected) + assertNotNull(disconnectError) + assertEquals("connection reset by peer", disconnectError!!.message) + conn.disconnect() + } + + @Test + fun `transport error fails pending subscription`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Subscribe but don't send SubscribeApplied + val handle = conn.subscribe(listOf("SELECT * FROM player")) + advanceUntilIdle() + assertTrue(handle.isPending) + + // Kill the transport — pending subscription should be failed + transport.closeWithError(RuntimeException("network error")) + advanceUntilIdle() + + assertTrue(handle.isEnded) + conn.disconnect() + } + + @Test + fun `transport error fails pending reducer 
callback`() = runTest { + val transport = FakeTransport() + val conn = buildTestConnection(transport) + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Call reducer but don't send result + var callbackFired = false + conn.callReducer("add", byteArrayOf(), "args", callback = { _ -> + callbackFired = true + }) + advanceUntilIdle() + + // Kill the transport — pending callback should be cleared + transport.closeWithError(RuntimeException("network error")) + advanceUntilIdle() + + // The callback should NOT have been fired (no result arrived) + assertFalse(callbackFired) + conn.disconnect() + } + + @Test + fun `send error does not crash receive loop`() = runTest { + val transport = FakeTransport() + // Use a CoroutineExceptionHandler so the unhandled send-loop exception + // doesn't propagate to runTest — we're testing that the receive loop survives. + val handler = kotlinx.coroutines.CoroutineExceptionHandler { _, _ -> } + val conn = DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler) + handler), + onConnectCallbacks = emptyList(), + onDisconnectCallbacks = emptyList(), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + conn.connect() + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + // Make sends fail + transport.sendError = RuntimeException("write failed") + + // The send loop dies, but the receive loop should still be active + conn.callReducer("add", byteArrayOf(), "args") + advanceUntilIdle() + + // Connection should still receive messages + val cache = createSampleCache() + conn.clientCache.register("sample", cache) + val handle = conn.subscribe(listOf("SELECT * FROM sample")) + advanceUntilIdle() + + // The subscribe message was dropped (send loop is dead), + // but we can still feed a SubscribeApplied to verify the receive loop is 
alive + transport.sendToClient( + ServerMessage.SubscribeApplied( + requestId = 1u, + querySetId = handle.querySetId, + rows = QueryRows(listOf(SingleTableRows("sample", buildRowList(SampleRow(1, "Alice").encode())))), + ) + ) + advanceUntilIdle() + + assertEquals(1, cache.count()) + conn.disconnect() + } + + // --- Raw transport: partial/corrupted frame handling --- + + @Test + fun `truncated bsatn frame fires on disconnect`() = runTest { + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? = null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + advanceUntilIdle() + + // Send a valid InitialConnection first, then a truncated frame + val writer = BsatnWriter() + writer.writeSumTag(0u) // InitialConnection tag + writer.writeU256(TEST_IDENTITY.data) // identity + writer.writeU128(TEST_CONNECTION_ID.data) // connectionId + writer.writeString(TEST_TOKEN) // token + rawTransport.sendRawToClient(writer.toByteArray()) + advanceUntilIdle() + + // Now send a truncated frame — only the tag byte, missing all fields + rawTransport.sendRawToClient(byteArrayOf(0x00)) + advanceUntilIdle() + + assertNotNull(disconnectError) + conn.disconnect() + } + + @Test + fun `invalid server message tag fires on disconnect`() = runTest { + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? = null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + advanceUntilIdle() + + // Send a frame with an invalid sum tag (255) + rawTransport.sendRawToClient(byteArrayOf(0xFF.toByte())) + advanceUntilIdle() + + assertNotNull(disconnectError) + assertTrue(disconnectError!!.message!!.contains("Unknown ServerMessage tag")) + conn.disconnect() + } + + @Test + fun `empty frame fires on disconnect`() = runTest { + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? 
= null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + advanceUntilIdle() + + // Send an empty byte array — BsatnReader will fail to read even the tag byte + rawTransport.sendRawToClient(byteArrayOf()) + advanceUntilIdle() + + assertNotNull(disconnectError) + conn.disconnect() + } + + @Test + fun `truncated mid field disconnects`() = runTest { + // Valid tag (6 = ReducerResultMsg) + valid requestId, but truncated before timestamp + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? = null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + rawTransport.sendRawToClient(encodeInitialConnectionBytes()) + advanceUntilIdle() + assertTrue(conn.isActive) + + val w = BsatnWriter() + w.writeSumTag(6u) // ReducerResultMsg + w.writeU32(1u) // requestId — valid + // Missing: timestamp + ReducerOutcome + rawTransport.sendRawToClient(w.toByteArray()) + advanceUntilIdle() + + assertNotNull(disconnectError, "Truncated mid-field should fire onDisconnect with error") + assertFalse(conn.isActive) + } + + @Test + fun `invalid nested option tag disconnects`() = runTest { + // SubscriptionError (tag 3) has Option for requestId — inject invalid option tag + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? 
= null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + rawTransport.sendRawToClient(encodeInitialConnectionBytes()) + advanceUntilIdle() + + val w = BsatnWriter() + w.writeSumTag(3u) // SubscriptionError + w.writeSumTag(99u) // Invalid Option tag (should be 0=Some or 1=None) + rawTransport.sendRawToClient(w.toByteArray()) + advanceUntilIdle() + + assertNotNull(disconnectError) + assertTrue(disconnectError!!.message!!.contains("Invalid Option tag")) + } + + @Test + fun `invalid result tag in one off query disconnects`() = runTest { + // OneOffQueryResult (tag 5) has Result — inject invalid result tag + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? = null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + rawTransport.sendRawToClient(encodeInitialConnectionBytes()) + advanceUntilIdle() + + val w = BsatnWriter() + w.writeSumTag(5u) // OneOffQueryResult + w.writeU32(42u) // requestId + w.writeSumTag(77u) // Invalid Result tag (should be 0=Ok or 1=Err) + rawTransport.sendRawToClient(w.toByteArray()) + advanceUntilIdle() + + assertNotNull(disconnectError) + assertTrue(disconnectError!!.message!!.contains("Invalid Result tag")) + } + + @Test + fun `oversized string length disconnects`() = runTest { + // Valid InitialConnection tag + identity + connectionId + string with huge length prefix + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? 
= null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + advanceUntilIdle() + + val w = BsatnWriter() + w.writeSumTag(0u) // InitialConnection + w.writeU256(TEST_IDENTITY.data) + w.writeU128(TEST_CONNECTION_ID.data) + w.writeU32(0xFFFFFFFFu) // String length = 4GB — way more than remaining bytes + rawTransport.sendRawToClient(w.toByteArray()) + advanceUntilIdle() + + assertNotNull(disconnectError) + } + + @Test + fun `invalid reducer outcome tag disconnects`() = runTest { + // ReducerResultMsg (tag 6) with valid fields but invalid ReducerOutcome tag + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? = null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + rawTransport.sendRawToClient(encodeInitialConnectionBytes()) + advanceUntilIdle() + + val w = BsatnWriter() + w.writeSumTag(6u) // ReducerResultMsg + w.writeU32(1u) // requestId + w.writeI64(12345L) // timestamp (Timestamp = i64 microseconds) + w.writeSumTag(200u) // Invalid ReducerOutcome tag + rawTransport.sendRawToClient(w.toByteArray()) + advanceUntilIdle() + + assertNotNull(disconnectError) + } + + @Test + fun `corrupt frame after established connection fails pending ops`() = runTest { + // Establish full connection with subscriptions/reducers, then corrupt frame + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? 
= null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + rawTransport.sendRawToClient(encodeInitialConnectionBytes()) + advanceUntilIdle() + assertTrue(conn.isActive) + + // Fire a reducer call so there's a pending operation + var callbackFired = false + conn.callReducer("test", byteArrayOf(), "args", callback = { _ -> callbackFired = true }) + advanceUntilIdle() + assertEquals(1, conn.stats.reducerRequestTracker.requestsAwaitingResponse) + + // Corrupt frame kills the connection + rawTransport.sendRawToClient(byteArrayOf(0xFE.toByte())) + advanceUntilIdle() + + assertNotNull(disconnectError) + assertFalse(conn.isActive) + // Reducer callback should NOT have fired (it was discarded, not responded to) + assertFalse(callbackFired) + } + + @Test + fun `garbage after valid message is ignored`() = runTest { + // A fully valid InitialConnection with extra trailing bytes appended. + // BsatnReader doesn't check that all bytes are consumed, so this should work. + val rawTransport = RawFakeTransport() + var connected = false + var disconnectError: Throwable? 
= null + val conn = DbConnection( + transport = rawTransport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = listOf { _, _, _ -> connected = true }, + onDisconnectCallbacks = listOf { _, err -> disconnectError = err }, + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + conn.connect() + advanceUntilIdle() + + val validBytes = encodeInitialConnectionBytes() + val withTrailing = validBytes + byteArrayOf(0xDE.toByte(), 0xAD.toByte(), 0xBE.toByte(), 0xEF.toByte()) + rawTransport.sendRawToClient(withTrailing) + advanceUntilIdle() + + // Connection should succeed — trailing bytes are not consumed but not checked + assertTrue(connected, "Valid message with trailing garbage should still connect") + assertNull(disconnectError, "Trailing garbage should not cause disconnect") + conn.disconnect() + } + + @Test + fun `all zero bytes frame disconnects`() = runTest { + // A frame of all zeroes — tag 0 (InitialConnection) but fields are all zeroes, + // which will produce a truncated read since the string length is 0 but + // Identity (32 bytes) and ConnectionId (16 bytes) consume the buffer first + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? = null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + advanceUntilIdle() + + // 10 zero bytes: tag=0 (InitialConnection), then only 9 bytes for Identity (needs 32) + rawTransport.sendRawToClient(ByteArray(10)) + advanceUntilIdle() + + assertNotNull(disconnectError) + } + + @Test + fun `valid tag with random garbage fields disconnects`() = runTest { + // SubscribeApplied (tag 1) followed by random garbage that doesn't form valid QueryRows + val rawTransport = RawFakeTransport() + var disconnectError: Throwable? 
= null + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, err -> + disconnectError = err + }) + conn.connect() + rawTransport.sendRawToClient(encodeInitialConnectionBytes()) + advanceUntilIdle() + + val w = BsatnWriter() + w.writeSumTag(1u) // SubscribeApplied + w.writeU32(1u) // requestId + w.writeU32(1u) // querySetId + // QueryRows needs: array_len (u32) + table entries — write nonsensical large array len + w.writeU32(999999u) // array_len for QueryRows — far more than available bytes + rawTransport.sendRawToClient(w.toByteArray()) + advanceUntilIdle() + + assertNotNull(disconnectError) + } + + @Test + fun `valid frame after corrupted frame is not processed`() = runTest { + val rawTransport = RawFakeTransport() + var disconnected = false + val conn = createConnectionWithTransport(rawTransport, onDisconnect = { _, _ -> + disconnected = true + }) + conn.connect() + advanceUntilIdle() + + // Send a corrupted frame — this kills the receive loop + rawTransport.sendRawToClient(byteArrayOf(0xFF.toByte())) + advanceUntilIdle() + assertTrue(disconnected) + + // The connection is now disconnected; identity should NOT be set + // even if we somehow send a valid InitialConnection afterward + assertNull(conn.identity) + conn.disconnect() + } + + // --- Protocol validation --- + + @Test + fun `invalid protocol throws on connect`() = runTest { + val transport = SpacetimeTransport( + client = HttpClient(), + baseUrl = "ftp://example.com", + nameOrAddress = "test", + connectionId = ConnectionId.random(), + ) + val conn = DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = emptyList(), + onDisconnectCallbacks = emptyList(), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + var connectError: Throwable? 
= null + conn.onConnectError { _, err -> connectError = err } + + conn.connect() + advanceUntilIdle() + + val err = assertNotNull(connectError) + assertContains(assertNotNull(err.message), "Unsupported protocol") + assertFalse(conn.isActive) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TypeRoundTripTest.kt b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TypeRoundTripTest.kt new file mode 100644 index 00000000000..40688740060 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/TypeRoundTripTest.kt @@ -0,0 +1,622 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Counter +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ScheduleAt +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.SpacetimeUuid +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.TimeDuration +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.UuidVersion +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.microseconds +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +class TypeRoundTripTest { + private fun encodeDecode(encode: (BsatnWriter) -> Unit, decode: (BsatnReader) -> T): T { + val writer = 
BsatnWriter() + encode(writer) + val reader = BsatnReader(writer.toByteArray()) + val result = decode(reader) + assertEquals(0, reader.remaining, "All bytes should be consumed") + return result + } + + // ---- ConnectionId ---- + + @Test + fun `connection id round trip`() { + val id = ConnectionId.random() + val decoded = encodeDecode({ id.encode(it) }, { ConnectionId.decode(it) }) + assertEquals(id, decoded) + } + + @Test + fun `connection id zero`() { + val zero = ConnectionId.zero() + assertTrue(zero.isZero()) + val decoded = encodeDecode({ zero.encode(it) }, { ConnectionId.decode(it) }) + assertEquals(zero, decoded) + assertTrue(decoded.isZero()) + } + + @Test + fun `connection id hex round trip`() { + val id = ConnectionId.random() + val hex = id.toHexString() + val restored = ConnectionId.fromHexString(hex) + assertEquals(id, restored) + } + + @Test + fun `connection id to byte array is little endian`() { + // ConnectionId with value 1 should have byte[0] = 1, rest zeros + val id = ConnectionId(BigInteger.ONE) + val bytes = id.toByteArray() + assertEquals(16, bytes.size) + assertEquals(1.toByte(), bytes[0]) + for (i in 1 until 16) { + assertEquals(0.toByte(), bytes[i], "Byte at index $i should be 0") + } + } + + @Test + fun `connection id null if zero`() { + assertEquals(ConnectionId.nullIfZero(ConnectionId.zero()), null) + assertTrue(ConnectionId.nullIfZero(ConnectionId.random()) != null) + } + + @Test + fun `connection id max value round trip`() { + // U128 max = 2^128 - 1 (all bits set) + val maxU128 = BigInteger.ONE.shl(128) - BigInteger.ONE + val id = ConnectionId(maxU128) + val decoded = encodeDecode({ id.encode(it) }, { ConnectionId.decode(it) }) + assertEquals(id, decoded) + assertEquals("f".repeat(32), decoded.toHexString()) + } + + @Test + fun `connection id high bit set round trip`() { + // Value with MSB set — tests BigInteger sign handling + val highBit = BigInteger.ONE.shl(127) + val id = ConnectionId(highBit) + val decoded = encodeDecode({ 
id.encode(it) }, { ConnectionId.decode(it) }) + assertEquals(id, decoded) + } + + // ---- Identity ---- + + @Test + fun `identity round trip`() { + val id = Identity(BigInteger.parseString("12345678901234567890")) + val decoded = encodeDecode({ id.encode(it) }, { Identity.decode(it) }) + assertEquals(id, decoded) + } + + @Test + fun `identity zero`() { + val zero = Identity.zero() + val decoded = encodeDecode({ zero.encode(it) }, { Identity.decode(it) }) + assertEquals(zero, decoded) + } + + @Test + fun `identity hex round trip`() { + val id = Identity(BigInteger.parseString("999888777666555444333222111")) + val hex = id.toHexString() + assertEquals(64, hex.length, "Identity hex should be 64 chars (32 bytes)") + val restored = Identity.fromHexString(hex) + assertEquals(id, restored) + } + + @Test + fun `identity to byte array is little endian`() { + val id = Identity(BigInteger.ONE) + val bytes = id.toByteArray() + assertEquals(32, bytes.size) + assertEquals(1.toByte(), bytes[0]) + for (i in 1 until 32) { + assertEquals(0.toByte(), bytes[i], "Byte at index $i should be 0") + } + } + + @Test + fun `identity max value round trip`() { + // U256 max = 2^256 - 1 (all bits set) + val maxU256 = BigInteger.ONE.shl(256) - BigInteger.ONE + val id = Identity(maxU256) + val decoded = encodeDecode({ id.encode(it) }, { Identity.decode(it) }) + assertEquals(id, decoded) + assertEquals("f".repeat(64), decoded.toHexString()) + } + + @Test + fun `identity high bit set round trip`() { + // Value with MSB set — tests BigInteger sign handling + val highBit = BigInteger.ONE.shl(255) + val id = Identity(highBit) + val decoded = encodeDecode({ id.encode(it) }, { Identity.decode(it) }) + assertEquals(id, decoded) + } + + @Test + fun `identity compare to ordering`() { + val small = Identity(BigInteger.ONE) + val large = Identity(BigInteger.parseString("999999999999999999999999999")) + assertTrue(small < large) + assertTrue(large > small) + assertEquals(0, small.compareTo(small)) + } + + // 
---- Timestamp ---- + + @Test + fun `timestamp round trip`() { + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val decoded = encodeDecode({ ts.encode(it) }, { Timestamp.decode(it) }) + assertEquals(ts, decoded) + } + + @Test + fun `timestamp epoch`() { + val epoch = Timestamp.UNIX_EPOCH + assertEquals(0L, epoch.microsSinceUnixEpoch) + val decoded = encodeDecode({ epoch.encode(it) }, { Timestamp.decode(it) }) + assertEquals(epoch, decoded) + } + + @Test + fun `timestamp negative round trip`() { + // 1969-12-31T23:59:59.000000Z — 1 second before epoch + val ts = Timestamp.fromEpochMicroseconds(-1_000_000L) + val decoded = encodeDecode({ ts.encode(it) }, { Timestamp.decode(it) }) + assertEquals(ts, decoded) + assertEquals(-1_000_000L, decoded.microsSinceUnixEpoch) + } + + @Test + fun `timestamp negative with micros round trip`() { + // Fractional negative: -0.5 seconds = -500_000 micros + val ts = Timestamp.fromEpochMicroseconds(-500_000L) + val decoded = encodeDecode({ ts.encode(it) }, { Timestamp.decode(it) }) + assertEquals(ts, decoded) + assertEquals(-500_000L, decoded.microsSinceUnixEpoch) + } + + @Test + fun `timestamp plus minus duration`() { + val ts = Timestamp.fromEpochMicroseconds(1_000_000L) // 1 second + val dur = TimeDuration(500_000.microseconds) // 0.5 seconds + val later = ts + dur + assertEquals(1_500_000L, later.microsSinceUnixEpoch) + val earlier = later - dur + assertEquals(ts, earlier) + } + + @Test + fun `timestamp difference`() { + val ts1 = Timestamp.fromEpochMicroseconds(3_000_000L) + val ts2 = Timestamp.fromEpochMicroseconds(1_000_000L) + val diff = ts1 - ts2 + assertEquals(2_000_000L, diff.micros) + } + + @Test + fun `timestamp comparison`() { + val earlier = Timestamp.fromEpochMicroseconds(100L) + val later = Timestamp.fromEpochMicroseconds(200L) + assertTrue(earlier < later) + assertTrue(later > earlier) + } + + @Test + fun `timestamp to iso string epoch`() { + assertEquals("1970-01-01T00:00:00.000000Z", 
Timestamp.UNIX_EPOCH.toISOString()) + } + + @Test + fun `timestamp to iso string pre epoch`() { + // 1 second before epoch + val ts = Timestamp.fromEpochMicroseconds(-1_000_000L) + assertEquals("1969-12-31T23:59:59.000000Z", ts.toISOString()) + } + + @Test + fun `timestamp to iso string pre epoch fractional`() { + // 0.5 seconds before epoch + val ts = Timestamp.fromEpochMicroseconds(-500_000L) + assertEquals("1969-12-31T23:59:59.500000Z", ts.toISOString()) + } + + @Test + fun `timestamp to iso string known date`() { + // 2023-11-14T22:13:20.000000Z = 1_700_000_000_000_000 micros + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + assertEquals("2023-11-14T22:13:20.000000Z", ts.toISOString()) + } + + @Test + fun `timestamp to iso string microsecond precision`() { + // 1 second + 123456 microseconds + val ts = Timestamp.fromEpochMicroseconds(1_123_456L) + assertEquals("1970-01-01T00:00:01.123456Z", ts.toISOString()) + } + + @Test + fun `timestamp to iso string pads leading zeros`() { + // 1 second + 7 microseconds — should pad to 6 digits + val ts = Timestamp.fromEpochMicroseconds(1_000_007L) + assertEquals("1970-01-01T00:00:01.000007Z", ts.toISOString()) + } + + @Test + fun `timestamp to string matches to iso string`() { + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_123_456L) + assertEquals(ts.toISOString(), ts.toString()) + } + + // ---- TimeDuration ---- + + @Test + fun `time duration round trip`() { + val dur = TimeDuration(123_456.microseconds) + val decoded = encodeDecode({ dur.encode(it) }, { TimeDuration.decode(it) }) + assertEquals(dur, decoded) + } + + @Test + fun `time duration arithmetic`() { + val a = TimeDuration(1.seconds) + val b = TimeDuration(500.milliseconds) + val sum = a + b + assertEquals(1_500_000L, sum.micros) + val diff = a - b + assertEquals(500_000L, diff.micros) + } + + @Test + fun `time duration comparison`() { + val shorter = TimeDuration(100.milliseconds) + val longer = TimeDuration(200.milliseconds) + 
assertTrue(shorter < longer) + } + + @Test + fun `time duration from millis`() { + val dur = TimeDuration.fromMillis(500) + assertEquals(500L, dur.millis) + assertEquals(500_000L, dur.micros) + } + + @Test + fun `time duration to string`() { + val positive = TimeDuration(5_123_456.microseconds) + assertEquals("+5.123456", positive.toString()) + + val negative = TimeDuration((-2_000_000).microseconds) + assertEquals("-2.000000", negative.toString()) + } + + // ---- ScheduleAt ---- + + @Test + fun `schedule at interval round trip`() { + val interval = ScheduleAt.interval(5.seconds) + val decoded = encodeDecode({ interval.encode(it) }, { ScheduleAt.decode(it) }) + assertTrue(decoded is ScheduleAt.Interval) + assertEquals((interval as ScheduleAt.Interval).duration, decoded.duration) + } + + @Test + fun `schedule at time round trip`() { + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val time = ScheduleAt.Time(ts) + val decoded = encodeDecode({ time.encode(it) }, { ScheduleAt.decode(it) }) + assertTrue(decoded is ScheduleAt.Time) + assertEquals(ts, decoded.timestamp) + } + + // ---- SpacetimeUuid ---- + + @Test + fun `spacetime uuid round trip`() { + val uuid = SpacetimeUuid.random() + val decoded = encodeDecode({ uuid.encode(it) }, { SpacetimeUuid.decode(it) }) + assertEquals(uuid, decoded) + } + + @Test + fun `spacetime uuid nil`() { + assertEquals(UuidVersion.Nil, SpacetimeUuid.NIL.getVersion()) + } + + @Test + fun `spacetime uuid v4 detection`() { + // Build a V4 UUID from known bytes + val bytes = ByteArray(16) { 0x42 } + val v4 = SpacetimeUuid.fromRandomBytesV4(bytes) + assertEquals(UuidVersion.V4, v4.getVersion()) + } + + @Test + fun `spacetime uuid v7 detection`() { + val counter = Counter() + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val randomBytes = byteArrayOf(0x01, 0x02, 0x03, 0x04) + val v7 = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + assertEquals(UuidVersion.V7, v7.getVersion()) + } + + @Test + 
fun `spacetime uuid v7 counter extraction`() { + val counter = Counter() + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val randomBytes = byteArrayOf(0x01, 0x02, 0x03, 0x04) + + val uuid0 = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + assertEquals(0, uuid0.getCounter()) + + val uuid1 = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + assertEquals(1, uuid1.getCounter()) + } + + @Test + fun `spacetime uuid compare to ordering`() { + val a = SpacetimeUuid.parse("00000000-0000-0000-0000-000000000001") + val b = SpacetimeUuid.parse("00000000-0000-0000-0000-000000000002") + assertTrue(a < b) + assertEquals(0, a.compareTo(a)) + } + + @Test + fun `spacetime uuid v7 timestamp encoding`() { + val counter = Counter() + // 1_700_000_000_000_000 microseconds = 1_700_000_000_000 ms + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val randomBytes = byteArrayOf(0x01, 0x02, 0x03, 0x04) + val uuid = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + val b = uuid.toByteArray() + + // Extract 48-bit timestamp from bytes 0-5 (big-endian) + val tsMs = (b[0].toLong() and 0xFF shl 40) or + (b[1].toLong() and 0xFF shl 32) or + (b[2].toLong() and 0xFF shl 24) or + (b[3].toLong() and 0xFF shl 16) or + (b[4].toLong() and 0xFF shl 8) or + (b[5].toLong() and 0xFF) + assertEquals(1_700_000_000_000L, tsMs) + } + + @Test + fun `spacetime uuid v7 version and variant bits`() { + val counter = Counter() + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val randomBytes = byteArrayOf(0x01, 0x02, 0x03, 0x04) + val uuid = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + val b = uuid.toByteArray() + + // Byte 6 high nibble must be 0x7 (version 7) + assertEquals(0x07, (b[6].toInt() shr 4) and 0x0F) + // Byte 8 high 2 bits must be 0b10 (variant RFC 4122) + assertEquals(0x02, (b[8].toInt() shr 6) and 0x03) + } + + @Test + fun `spacetime uuid v7 counter wraparound`() { + // Counter wraps at 0x7FFF_FFFF + val counter = 
Counter(0x7FFF_FFFE) + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val randomBytes = byteArrayOf(0x01, 0x02, 0x03, 0x04) + + val uuid1 = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + assertEquals(0x7FFF_FFFE, uuid1.getCounter()) + + val uuid2 = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + assertEquals(0x7FFF_FFFF, uuid2.getCounter()) + + // Next increment wraps to 0 + val uuid3 = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + assertEquals(0, uuid3.getCounter()) + } + + @Test + fun `spacetime uuid v7 round trip`() { + val counter = Counter() + val ts = Timestamp.fromEpochMicroseconds(1_700_000_000_000_000L) + val randomBytes = byteArrayOf(0x01, 0x02, 0x03, 0x04) + val uuid = SpacetimeUuid.fromCounterV7(counter, ts, randomBytes) + val decoded = encodeDecode({ uuid.encode(it) }, { SpacetimeUuid.decode(it) }) + assertEquals(uuid, decoded) + } + + // ---- Int128 ---- + + @Test + fun `int128 round trip`() { + val v = Int128(BigInteger.parseString("170141183460469231731687303715884105727")) // 2^127 - 1 + val decoded = encodeDecode({ v.encode(it) }, { Int128.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `int128 zero round trip`() { + val v = Int128(BigInteger.ZERO) + val decoded = encodeDecode({ v.encode(it) }, { Int128.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `int128 negative round trip`() { + val v = Int128(-BigInteger.ONE.shl(127)) // -2^127 (I128 min) + val decoded = encodeDecode({ v.encode(it) }, { Int128.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `int128 compare to ordering`() { + val neg = Int128(-BigInteger.ONE) + val zero = Int128(BigInteger.ZERO) + val pos = Int128(BigInteger.ONE) + assertTrue(neg < zero) + assertTrue(zero < pos) + assertEquals(0, zero.compareTo(zero)) + } + + @Test + fun `int128 to string`() { + val v = Int128(BigInteger.parseString("42")) + assertEquals("42", v.toString()) + } + + // ---- UInt128 ---- + + @Test + fun `uint128 round trip`() 
{ + val v = UInt128(BigInteger.ONE.shl(128) - BigInteger.ONE) // 2^128 - 1 + val decoded = encodeDecode({ v.encode(it) }, { UInt128.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `uint128 zero round trip`() { + val v = UInt128(BigInteger.ZERO) + val decoded = encodeDecode({ v.encode(it) }, { UInt128.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `uint128 high bit set round trip`() { + val v = UInt128(BigInteger.ONE.shl(127)) + val decoded = encodeDecode({ v.encode(it) }, { UInt128.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `uint128 compare to ordering`() { + val small = UInt128(BigInteger.ONE) + val large = UInt128(BigInteger.ONE.shl(100)) + assertTrue(small < large) + assertEquals(0, small.compareTo(small)) + } + + // ---- Int256 ---- + + @Test + fun `int256 round trip`() { + val v = Int256(BigInteger.ONE.shl(255) - BigInteger.ONE) // 2^255 - 1 (I256 max) + val decoded = encodeDecode({ v.encode(it) }, { Int256.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `int256 zero round trip`() { + val v = Int256(BigInteger.ZERO) + val decoded = encodeDecode({ v.encode(it) }, { Int256.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `int256 negative round trip`() { + val v = Int256(-BigInteger.ONE.shl(255)) // -2^255 (I256 min) + val decoded = encodeDecode({ v.encode(it) }, { Int256.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `int256 compare to ordering`() { + val neg = Int256(-BigInteger.ONE) + val pos = Int256(BigInteger.ONE) + assertTrue(neg < pos) + } + + // ---- UInt256 ---- + + @Test + fun `uint256 round trip`() { + val v = UInt256(BigInteger.ONE.shl(256) - BigInteger.ONE) // 2^256 - 1 + val decoded = encodeDecode({ v.encode(it) }, { UInt256.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun `uint256 zero round trip`() { + val v = UInt256(BigInteger.ZERO) + val decoded = encodeDecode({ v.encode(it) }, { UInt256.decode(it) }) + assertEquals(v, decoded) + } + + @Test + fun 
`uint256 high bit set round trip`() { + val v = UInt256(BigInteger.ONE.shl(255)) + val decoded = encodeDecode({ v.encode(it) }, { UInt256.decode(it) }) + assertEquals(v, decoded) + } + + // ---- SpacetimeResult ---- + + @Test + fun `spacetime result ok round trip`() { + val writer = BsatnWriter() + // Encode: tag 0 + I32 + writer.writeSumTag(0u) + writer.writeI32(42) + val reader = BsatnReader(writer.toByteArray()) + val tag = reader.readSumTag().toInt() + assertEquals(0, tag) + val value = reader.readI32() + assertEquals(42, value) + assertEquals(0, reader.remaining) + } + + @Test + fun `spacetime result err round trip`() { + val writer = BsatnWriter() + // Encode: tag 1 + String + writer.writeSumTag(1u) + writer.writeString("oops") + val reader = BsatnReader(writer.toByteArray()) + val tag = reader.readSumTag().toInt() + assertEquals(1, tag) + val error = reader.readString() + assertEquals("oops", error) + assertEquals(0, reader.remaining) + } + + @Test + fun `spacetime result ok type`() { + val result: SpacetimeResult = SpacetimeResult.Ok(42) + assertIs>(result) + assertEquals(42, result.value) + } + + @Test + fun `spacetime result err type`() { + val result: SpacetimeResult = SpacetimeResult.Err("oops") + assertIs>(result) + assertEquals("oops", result.error) + } + + @Test + fun `spacetime result when exhaustive`() { + val ok: SpacetimeResult = SpacetimeResult.Ok(1) + val err: SpacetimeResult = SpacetimeResult.Err("e") + // Verify exhaustive when works (sealed interface) + val okMsg = when (ok) { + is SpacetimeResult.Ok -> "ok:${ok.value}" + is SpacetimeResult.Err -> "err:${ok.error}" + } + assertEquals("ok:1", okMsg) + val errMsg = when (err) { + is SpacetimeResult.Ok -> "ok:${err.value}" + is SpacetimeResult.Err -> "err:${err.error}" + } + assertEquals("err:e", errMsg) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UtilTest.kt 
b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UtilTest.kt new file mode 100644 index 00000000000..8348c844158 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/commonTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/UtilTest.kt @@ -0,0 +1,90 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.time.Instant + +class UtilTest { + // ---- BigInteger hex round-trip ---- + + @Test + fun `hex round trip 16 bytes`() { + val value = BigInteger.parseString("12345678901234567890abcdef", 16) + val hex = value.toHexString(16) // 16 bytes = 32 hex chars + assertEquals(32, hex.length) + val restored = parseHexString(hex) + assertEquals(value, restored) + } + + @Test + fun `hex round trip 32 bytes`() { + val value = BigInteger.parseString("abcdef0123456789abcdef0123456789", 16) + val hex = value.toHexString(32) // 32 bytes = 64 hex chars + assertEquals(64, hex.length) + val restored = parseHexString(hex) + assertEquals(value, restored) + } + + @Test + fun `hex zero value`() { + val zero = BigInteger.ZERO + val hex16 = zero.toHexString(16) + assertEquals("00000000000000000000000000000000", hex16) + assertEquals(BigInteger.ZERO, parseHexString(hex16)) + + val hex32 = zero.toHexString(32) + assertEquals("0000000000000000000000000000000000000000000000000000000000000000", hex32) + assertEquals(BigInteger.ZERO, parseHexString(hex32)) + } + + // ---- Instant microsecond round-trip ---- + + @Test + fun `instant microsecond round trip`() { + val micros = 1_700_000_000_123_456L + val instant = Instant.fromEpochMicroseconds(micros) + val roundTripped = instant.toEpochMicroseconds() + assertEquals(micros, roundTripped) + } + + @Test + fun `instant microsecond zero`() { + val instant = Instant.fromEpochMicroseconds(0L) + assertEquals(0L, instant.toEpochMicroseconds()) + } + + @Test + fun 
`instant microsecond negative`() { + val micros = -1_000_000L // 1 second before epoch + val instant = Instant.fromEpochMicroseconds(micros) + assertEquals(micros, instant.toEpochMicroseconds()) + } + + @Test + fun `instant microsecond max round trips`() { + val micros = Long.MAX_VALUE + val instant = Instant.fromEpochMicroseconds(micros) + assertEquals(micros, instant.toEpochMicroseconds()) + } + + @Test + fun `instant microsecond min round trips`() { + // Long.MIN_VALUE doesn't land on an exact second boundary, so + // floorDiv pushes it one second beyond the representable range. + // Use the actual minimum that round-trips cleanly. + val minSeconds = Long.MIN_VALUE / 1_000_000L + val micros = minSeconds * 1_000_000L + val instant = Instant.fromEpochMicroseconds(micros) + assertEquals(micros, instant.toEpochMicroseconds()) + } + + @Test + fun `instant beyond microsecond range throws`() { + // An Instant far beyond the I64 microsecond wire format range + val farFuture = Instant.fromEpochSeconds(Long.MAX_VALUE / 1_000_000L + 1) + assertFailsWith<ArithmeticException> { + farFuture.toEpochMicroseconds() + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/jvmMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.jvm.kt b/sdks/kotlin/spacetimedb-sdk/src/jvmMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.jvm.kt new file mode 100644 index 00000000000..dcccafaebb3 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/jvmMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.jvm.kt @@ -0,0 +1,33 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPInputStream +import org.brotli.dec.BrotliInputStream + +internal actual fun decompressMessage(data: ByteArray): DecompressedPayload {
require(data.isNotEmpty()) { "Empty message" } + + return when (val tag = data[0]) { + Compression.NONE -> DecompressedPayload(data, offset = 1) + Compression.BROTLI -> { + val input = BrotliInputStream(ByteArrayInputStream(data, 1, data.size - 1)) + val output = ByteArrayOutputStream() + input.use { it.copyTo(output) } + DecompressedPayload(output.toByteArray()) + } + Compression.GZIP -> { + val input = GZIPInputStream(ByteArrayInputStream(data, 1, data.size - 1)) + val output = ByteArrayOutputStream() + input.use { it.copyTo(output) } + DecompressedPayload(output.toByteArray()) + } + else -> error("Unknown compression tag: $tag") + } +} + +internal actual val defaultCompressionMode: CompressionMode = CompressionMode.GZIP + +internal actual val availableCompressionModes: Set<CompressionMode> = + setOf(CompressionMode.NONE, CompressionMode.BROTLI, CompressionMode.GZIP) diff --git a/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackDispatcherTest.kt b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackDispatcherTest.kt new file mode 100644 index 00000000000..65c339e33fe --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/CallbackDispatcherTest.kt @@ -0,0 +1,64 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.DelicateCoroutinesApi +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.newSingleThreadContext +import kotlinx.coroutines.test.StandardTestDispatcher
+import kotlinx.coroutines.test.advanceUntilIdle +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +@OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class) +class CallbackDispatcherTest { + + private val testIdentity = Identity(BigInteger.ONE) + private val testConnectionId = ConnectionId(BigInteger.TWO) + private val testToken = "test-token-abc" + + private fun initialConnectionMsg() = ServerMessage.InitialConnection( + identity = testIdentity, + connectionId = testConnectionId, + token = testToken, + ) + + @Test + fun `callback dispatcher is used for callbacks`() = runTest { + val transport = FakeTransport() + + val callbackDispatcher = newSingleThreadContext("TestCallbackThread") + val callbackThreadDeferred = CompletableDeferred<String>() + + callbackDispatcher.use { callbackDispatcher -> + val conn = DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + StandardTestDispatcher(testScheduler)), + onConnectCallbacks = listOf { _, _, _ -> + callbackThreadDeferred.complete(Thread.currentThread().name) + }, + onDisconnectCallbacks = emptyList(), + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = callbackDispatcher, + ) + conn.connect() + transport.sendToClient(initialConnectionMsg()) + advanceUntilIdle() + + val capturedThread = callbackThreadDeferred.await() + advanceUntilIdle() + assertNotNull(capturedThread) + assertTrue(capturedThread.contains("TestCallbackThread")) + conn.disconnect() + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConcurrencyStressTest.kt b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConcurrencyStressTest.kt new file mode 100644 index 00000000000..d5f429755bb --- /dev/null +++
b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/ConcurrencyStressTest.kt @@ -0,0 +1,945 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.ServerMessage +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.TableUpdateRows +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.ConnectionId +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import java.util.concurrent.CyclicBarrier +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertSame +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds + +/** + * Concurrency stress tests for the lock-free data structures in the SDK. + * These run on JVM with real threads (Dispatchers.Default) to exercise + * CAS loops and atomic operations under actual contention. 
+ */ +class ConcurrencyStressTest { + + companion object { + private const val THREAD_COUNT = 16 + private const val OPS_PER_THREAD = 500 + } + + // ---- TableCache: concurrent inserts ---- + + @Test + fun `table cache concurrent inserts are not lost`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val totalRows = THREAD_COUNT * OPS_PER_THREAD + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val start = threadIdx * OPS_PER_THREAD + for (i in start until start + OPS_PER_THREAD) { + val row = SampleRow(i, "row-$i") + cache.applyInserts(STUB_CTX, buildRowList(row.encode())) + } + } + } + } + + assertEquals(totalRows, cache.count()) + val allIds = cache.all().map { it.id }.toSet() + assertEquals(totalRows, allIds.size) + for (i in 0 until totalRows) { + assertTrue(i in allIds, "Missing row id=$i") + } + } + + // ---- TableCache: concurrent inserts and deletes ---- + + @Test + fun `table cache concurrent insert and delete converges`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val barrier = CyclicBarrier(THREAD_COUNT) + + // Pre-insert rows that will be deleted + val deleteRange = 0 until (THREAD_COUNT / 2) * OPS_PER_THREAD + for (i in deleteRange) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "pre-$i").encode())) + } + + coroutineScope { + // Half the threads insert new rows + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = deleteRange.last + 1 + threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "new-$i").encode())) + } + } + } + // Half the threads delete pre-inserted rows + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val start = threadIdx * OPS_PER_THREAD + for (i in start until start + OPS_PER_THREAD) { + val row = SampleRow(i, "pre-$i") + val parsed = 
cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + } + } + } + + // All pre-inserted rows should be gone, all new rows should exist + val insertedCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + assertEquals(insertedCount, cache.count()) + for (row in cache.all()) { + assertTrue(row.name.startsWith("new-"), "Unexpected row: $row") + } + } + + // ---- TableCache: concurrent reads during writes ---- + + @Test + fun `table cache reads are consistent snapshots during writes`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + // Writers + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + } + } + // Readers: snapshot must always be self-consistent + repeat(THREAD_COUNT / 2) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { + val snapshot = cache.all() + cache.count() + // Snapshot is a point-in-time view — its size should be consistent + // (count() may differ since it reads a newer snapshot) + val ids = snapshot.map { it.id }.toSet() + assertEquals(snapshot.size, ids.size, "Snapshot contains duplicate IDs") + } + } + } + } + } + + // ---- TableCache: concurrent ref count increments and decrements ---- + + @Test + fun `table cache ref count survives concurrent increment decrement`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val sharedRow = SampleRow(42, "shared") + cache.applyInserts(STUB_CTX, buildRowList(sharedRow.encode())) + + val barrier = CyclicBarrier(THREAD_COUNT) + + // Each thread increments then decrements the refcount + coroutineScope { + repeat(THREAD_COUNT) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(sharedRow.encode())) + val parsed = 
cache.parseDeletes(buildRowList(sharedRow.encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + } + } + } + + // After all increments + decrements, refcount should be back to 1 + assertEquals(1, cache.count()) + assertEquals(sharedRow, cache.all().single()) + } + + // ---- UniqueIndex: consistent with cache under concurrent mutations ---- + + @Test + fun `unique index stays consistent under concurrent inserts`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + val totalRows = THREAD_COUNT * OPS_PER_THREAD + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + } + } + } + + // Every inserted row must be findable in the index + for (i in 0 until totalRows) { + val found = index.find(i) + assertEquals(SampleRow(i, "row-$i"), found, "Index missing row id=$i") + } + } + + @Test + fun `unique index stays consistent under concurrent inserts and deletes`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + // Pre-insert rows to delete + val deleteCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + for (i in 0 until deleteCount) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "pre-$i").encode())) + } + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + // Inserters + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = deleteCount + threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "new-$i").encode())) + } + } + } + // Deleters + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val start = threadIdx * OPS_PER_THREAD + for (i in start until start + 
OPS_PER_THREAD) { + val row = SampleRow(i, "pre-$i") + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + } + } + } + + // Deleted rows gone from index + for (i in 0 until deleteCount) { + assertEquals(null, index.find(i), "Deleted row id=$i still in index") + } + // New rows present in index + val insertBase = deleteCount + val insertCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + for (i in insertBase until insertBase + insertCount) { + assertEquals(SampleRow(i, "new-$i"), index.find(i), "Index missing new row id=$i") + } + } + + // ---- BTreeIndex: consistent under concurrent mutations ---- + + @Test + fun `btree index stays consistent under concurrent inserts`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + // Key on name — groups of rows share the same name + val groupCount = 10 + val index = BTreeIndex(cache) { it.name } + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + val groupName = "group-${i % groupCount}" + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, groupName).encode())) + } + } + } + } + + val totalRows = THREAD_COUNT * OPS_PER_THREAD + val expectedPerGroup = totalRows / groupCount + for (g in 0 until groupCount) { + val matches = index.filter("group-$g") + assertEquals(expectedPerGroup, matches.size, "Group group-$g count mismatch") + } + } + + // ---- Callback registration: concurrent add/remove during iteration ---- + + @Test + fun `callback registration survives concurrent add remove`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val callCount = AtomicInteger(0) + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + // Half the threads add and remove callbacks + repeat(THREAD_COUNT / 2) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { + 
val cb: (EventContext, SampleRow) -> Unit = { _, _ -> callCount.incrementAndGet() } + cache.onInsert(cb) + cache.removeOnInsert(cb) + } + } + } + // Other half trigger inserts (fires callbacks that are registered at snapshot time) + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + val callbacks = cache.applyInserts( + STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode()) + ) + callbacks.forEach { it.invoke() } + } + } + } + } + + // The test passes if no ConcurrentModificationException or lost update occurs. + // callCount can be anything (depends on timing), but count() must be exact. + val insertedCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + assertEquals(insertedCount, cache.count()) + } + + // ---- ClientCache.getOrCreateTable: concurrent creation of same table ---- + + @Test + fun `client cache get or create table is idempotent under contention`() = runBlocking(Dispatchers.Default) { + val clientCache = ClientCache() + val barrier = CyclicBarrier(THREAD_COUNT) + val creationCount = AtomicInteger(0) + + val results = coroutineScope { + (0 until THREAD_COUNT).map { + async { + barrier.await() + clientCache.getOrCreateTable("players") { + creationCount.incrementAndGet() + TableCache.withPrimaryKey(::decodeSampleRow) { it.id } + } + } + }.awaitAll() + } + + // All threads must get the same instance + val first = results.first() + for (table in results) { + assertSame(first, table, "Different table instance returned by getOrCreateTable") + } + // Factory is called by each thread that misses the fast path (line 447). + // Threads arriving after the table is visible skip factory entirely. + // CAS retries never re-invoke factory — it's hoisted outside the loop. + // In practice most threads miss the fast path under contention, but at least 1 must create. 
+ val count = creationCount.get() + assertTrue(count >= 1, "Factory must be called at least once, got: $count") + assertTrue(count <= THREAD_COUNT, "Factory called more than THREAD_COUNT times: $count") + } + + // ---- NetworkRequestTracker: concurrent start/finish ---- + + @Test + fun `network request tracker concurrent start finish`() = runBlocking(Dispatchers.Default) { + val tracker = NetworkRequestTracker() + val barrier = CyclicBarrier(THREAD_COUNT) + val totalOps = THREAD_COUNT * OPS_PER_THREAD + + coroutineScope { + repeat(THREAD_COUNT) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { + val id = tracker.startTrackingRequest("test") + tracker.finishTrackingRequest(id) + } + } + } + } + + assertEquals(totalOps, tracker.sampleCount) + assertEquals(0, tracker.requestsAwaitingResponse) + } + + @Test + fun `network request tracker concurrent insert sample`() = runBlocking(Dispatchers.Default) { + val tracker = NetworkRequestTracker() + val barrier = CyclicBarrier(THREAD_COUNT) + val totalOps = THREAD_COUNT * OPS_PER_THREAD + + coroutineScope { + repeat(THREAD_COUNT) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { i -> + tracker.insertSample((i + 1).milliseconds, "op-$i") + } + } + } + } + + assertEquals(totalOps, tracker.sampleCount) + // Min must be 1ms (smallest sample), max must be OPS_PER_THREAD ms + val result = assertNotNull(tracker.allTimeMinMax, "allTimeMinMax should not be null after $totalOps samples") + assertEquals(1.milliseconds, result.min.duration, "allTimeMin wrong: ${result.min}") + assertEquals(OPS_PER_THREAD.milliseconds, result.max.duration, "allTimeMax wrong: ${result.max}") + } + + // ---- Logger: concurrent level/handler read/write ---- + + @Test + fun `logger concurrent level and handler changes`() = runBlocking(Dispatchers.Default) { + val originalLevel = Logger.level + val originalHandler = Logger.handler + val barrier = CyclicBarrier(THREAD_COUNT) + val logCount = AtomicInteger(0) + + try { + coroutineScope 
{ + // Half the threads toggle the log level + repeat(THREAD_COUNT / 2) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { i -> + Logger.level = if (i % 2 == 0) LogLevel.DEBUG else LogLevel.ERROR + } + } + } + // Other half swap the handler and log + repeat(THREAD_COUNT / 2) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { + Logger.handler = LogHandler { _, _ -> logCount.incrementAndGet() } + Logger.info { "stress" } + } + } + } + } + // No crash or exception = pass. logCount is non-deterministic. + } finally { + Logger.level = originalLevel + Logger.handler = originalHandler + } + } + + // ---- Internal listeners: concurrent listener fire during add ---- + + @Test + fun `internal listeners fire safely during concurrent registration`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val listenerCallCount = AtomicInteger(0) + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + // Half add listeners + repeat(THREAD_COUNT / 2) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD) { + cache.addInternalInsertListener { listenerCallCount.incrementAndGet() } + } + } + } + // Half do inserts (which fire all currently-registered listeners) + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "r-$i").encode())) + } + } + } + } + + val insertedCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + assertEquals(insertedCount, cache.count()) + // Listener calls >= 0, no crash = pass + assertTrue(listenerCallCount.get() >= 0) + } + + // ---- TableCache clear() racing with inserts ---- + + @Test + fun `table cache clear racing with inserts`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + // One thread clears repeatedly + launch { + barrier.await() + 
repeat(OPS_PER_THREAD) { + cache.clear() + } + } + // Rest insert + repeat(THREAD_COUNT - 1) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "r-$i").encode())) + } + } + } + } + + // The final state depends on timing, but the cache must be internally consistent: + // count() == all().size, no duplicates in all() + val all = cache.all() + assertEquals(cache.count(), all.size) + val ids = all.map { it.id }.toSet() + assertEquals(all.size, ids.size, "Duplicate IDs after clear/insert race") + } + + // ---- UniqueIndex: reads during concurrent mutations ---- + + @Test + fun `unique index reads return consistent snapshots during mutations`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + // Writers + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "r-$i").encode())) + } + } + } + // Readers + repeat(THREAD_COUNT / 2) { _ -> + launch { + barrier.await() + repeat(OPS_PER_THREAD * 2) { i -> + val row = index.find(i) + // If found, it must be consistent + if (row != null) { + assertEquals(i, row.id, "Index returned wrong row for key=$i") + assertEquals("r-$i", row.name) + } + } + } + } + } + } + + // ---- BTreeIndex: concurrent insert/delete with group verification ---- + + @Test + fun `btree index group count converges after concurrent insert delete`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val index = BTreeIndex(cache) { it.name } + val groupName = "shared-group" + + // Pre-insert rows to delete + val deleteCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + for (i in 0 until deleteCount) { + 
cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, groupName).encode())) + } + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + // Insert new rows with same group + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = deleteCount + threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, groupName).encode())) + } + } + } + // Delete pre-inserted rows + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val start = threadIdx * OPS_PER_THREAD + for (i in start until start + OPS_PER_THREAD) { + val row = SampleRow(i, groupName) + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + } + } + } + + val expectedCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + val groupRows = index.filter(groupName) + assertEquals(expectedCount, groupRows.size, "BTreeIndex group count mismatch") + // Verify only new rows remain + for (row in groupRows) { + assertTrue(row.id >= deleteCount, "Deleted row still in BTreeIndex: $row") + } + } + + // ---- DbConnection: concurrent disconnect from multiple threads ---- + + @Test + fun `concurrent disconnect fires callback exactly once`() = runBlocking(Dispatchers.Default) { + val transport = FakeTransport() + val disconnectCount = AtomicInteger(0) + + val conn = DbConnection( + transport = transport, + scope = CoroutineScope(SupervisorJob() + Dispatchers.Default), + onConnectCallbacks = emptyList(), + onDisconnectCallbacks = listOf { _, _ -> disconnectCount.incrementAndGet() }, + onConnectErrorCallbacks = emptyList(), + clientConnectionId = ConnectionId.random(), + stats = Stats(), + moduleDescriptor = null, + callbackDispatcher = null, + ) + conn.connect() + transport.sendToClient( + ServerMessage.InitialConnection( + identity = Identity(BigInteger.ONE), + connectionId = ConnectionId(BigInteger.TWO), + token = "token", + ) + ) + // Give the receive loop 
time to process the initial connection + kotlinx.coroutines.delay(100) + assertTrue(conn.isActive) + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + repeat(THREAD_COUNT) { + launch { + barrier.await() + conn.disconnect() + } + } + } + + assertFalse(conn.isActive) + assertEquals(1, disconnectCount.get(), "onDisconnect must fire exactly once") + } + + // ---- TableCache: concurrent updates (combined delete+insert) ---- + + @Test + fun `table cache concurrent updates replace correctly`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val totalRows = THREAD_COUNT * OPS_PER_THREAD + // Pre-insert all rows with original names + for (i in 0 until totalRows) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "original-$i").encode())) + } + assertEquals(totalRows, cache.count()) + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + val oldRow = SampleRow(i, "original-$i") + val newRow = SampleRow(i, "updated-$i") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.applyUpdate(STUB_CTX, parsed) + } + } + } + } + + // All rows should be updated, count unchanged + assertEquals(totalRows, cache.count()) + for (row in cache.all()) { + assertTrue(row.name.startsWith("updated-"), "Row not updated: $row") + } + } + + // ---- TableCache: two-phase deletes under contention ---- + + @Test + fun `two phase deletes under contention`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val totalRows = THREAD_COUNT * OPS_PER_THREAD + for (i in 0 until totalRows) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + + val beforeDeleteCount = AtomicInteger(0) + val deleteCount = 
AtomicInteger(0) + cache.onBeforeDelete { _, _ -> beforeDeleteCount.incrementAndGet() } + cache.onDelete { _, _ -> deleteCount.incrementAndGet() } + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + val row = SampleRow(i, "row-$i") + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.preApplyDeletes(STUB_CTX, parsed) + val callbacks = cache.applyDeletes(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + } + } + } + } + + assertEquals(0, cache.count()) + assertEquals(totalRows, beforeDeleteCount.get()) + assertEquals(totalRows, deleteCount.get()) + } + + // ---- TableCache: two-phase updates under contention ---- + + @Test + fun `two phase updates under contention`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val totalRows = THREAD_COUNT * OPS_PER_THREAD + for (i in 0 until totalRows) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "v0-$i").encode())) + } + + val updateCount = AtomicInteger(0) + cache.onUpdate { _, _, _ -> updateCount.incrementAndGet() } + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + val oldRow = SampleRow(i, "v0-$i") + val newRow = SampleRow(i, "v1-$i") + val update = TableUpdateRows.PersistentTable( + inserts = buildRowList(newRow.encode()), + deletes = buildRowList(oldRow.encode()), + ) + val parsed = cache.parseUpdate(update) + cache.preApplyUpdate(STUB_CTX, parsed) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + } + } + } + } + + assertEquals(totalRows, cache.count()) + assertEquals(totalRows, updateCount.get()) + for (row in cache.all()) { + assertTrue(row.name.startsWith("v1-"), "Row not updated: 
$row") + } + } + + // ---- Content-key table: concurrent operations without primary key ---- + + @Test + fun `content key table concurrent inserts`() = runBlocking(Dispatchers.Default) { + val cache = TableCache.withContentKey(::decodeSampleRow) + val totalRows = THREAD_COUNT * OPS_PER_THREAD + val barrier = CyclicBarrier(THREAD_COUNT) + + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + } + } + } + + assertEquals(totalRows, cache.count()) + val allIds = cache.all().map { it.id }.toSet() + assertEquals(totalRows, allIds.size) + } + + @Test + fun `content key table concurrent insert and delete`() = runBlocking(Dispatchers.Default) { + val cache = TableCache.withContentKey(::decodeSampleRow) + + // Pre-insert rows to delete + val deleteCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + for (i in 0 until deleteCount) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "pre-$i").encode())) + } + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val base = deleteCount + threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "new-$i").encode())) + } + } + } + repeat(THREAD_COUNT / 2) { threadIdx -> + launch { + barrier.await() + val start = threadIdx * OPS_PER_THREAD + for (i in start until start + OPS_PER_THREAD) { + val row = SampleRow(i, "pre-$i") + val parsed = cache.parseDeletes(buildRowList(row.encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + } + } + } + + val expectedCount = (THREAD_COUNT / 2) * OPS_PER_THREAD + assertEquals(expectedCount, cache.count()) + for (row in cache.all()) { + assertTrue(row.name.startsWith("new-"), "Unexpected row: $row") + } + } + + // ---- Event table: concurrent 
fire-and-forget ---- + + @Test + fun `event table concurrent updates never store rows`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val insertCallbackCount = AtomicInteger(0) + cache.onInsert { _, _ -> insertCallbackCount.incrementAndGet() } + + val barrier = CyclicBarrier(THREAD_COUNT) + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + val event = TableUpdateRows.EventTable( + events = buildRowList(SampleRow(i, "evt-$i").encode()), + ) + val parsed = cache.parseUpdate(event) + val callbacks = cache.applyUpdate(STUB_CTX, parsed) + callbacks.forEach { it.invoke() } + } + } + } + } + + // Event rows must never persist + assertEquals(0, cache.count()) + // Every event should have fired a callback + assertEquals(THREAD_COUNT * OPS_PER_THREAD, insertCallbackCount.get()) + } + + // ---- Index construction from pre-populated cache under contention ---- + + @Test + fun `index construction during concurrent inserts`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val totalRows = THREAD_COUNT * OPS_PER_THREAD + val barrier = CyclicBarrier(THREAD_COUNT + 1) // +1 for index builder + + val indices = mutableListOf>() + + coroutineScope { + // Inserters + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + } + } + // Index builder — constructs index while inserts are in flight + launch { + barrier.await() + // Build index at various points during insertion + repeat(10) { + val index = UniqueIndex(cache) { it.id } + synchronized(indices) { indices.add(index) } + // Small yield to let inserts progress + kotlinx.coroutines.yield() + } + } + } + + // After all inserts complete, every index must be consistent with the final 
cache + assertEquals(totalRows, cache.count()) + for (index in indices) { + // Every row in the cache must be findable in every index + for (i in 0 until totalRows) { + val found = index.find(i) + assertEquals(SampleRow(i, "row-$i"), found, "Index missing row id=$i") + } + } + } + + // ---- ClientCache: concurrent operations across multiple tables ---- + + @Test + fun `client cache concurrent multi table operations`() = runBlocking(Dispatchers.Default) { + val clientCache = ClientCache() + val tableCount = 8 + val barrier = CyclicBarrier(THREAD_COUNT) + + // Each thread works on a different table (round-robin) + coroutineScope { + repeat(THREAD_COUNT) { threadIdx -> + launch { + barrier.await() + val tableName = "table-${threadIdx % tableCount}" + val table = clientCache.getOrCreateTable(tableName) { + TableCache.withPrimaryKey(::decodeSampleRow) { it.id } + } + val base = threadIdx * OPS_PER_THREAD + for (i in base until base + OPS_PER_THREAD) { + table.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + } + } + } + + // Verify all tables exist and have correct counts + val totalRows = THREAD_COUNT * OPS_PER_THREAD + var totalCount = 0 + val allIds = mutableSetOf() + for (t in 0 until tableCount) { + val table = clientCache.getTable("table-$t") + totalCount += table.count() + for (row in table.all()) { + assertTrue(allIds.add(row.id), "Duplicate row id=${row.id} across tables") + } + } + assertEquals(totalRows, totalCount) + assertEquals(totalRows, allIds.size) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IndexScaleTest.kt b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IndexScaleTest.kt new file mode 100644 index 00000000000..84f555835ae --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/IndexScaleTest.kt @@ -0,0 +1,451 @@ +package 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import java.util.concurrent.CyclicBarrier +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.measureTime + +/** + * Large-scale performance tests for UniqueIndex and BTreeIndex. + * These verify correctness and measure performance characteristics + * at row counts well beyond the functional test suite (which uses 2-8K rows). + * + * Run on JVM only — uses real threads for concurrent workloads and + * timing measurements via kotlin.time. + */ +class IndexScaleTest { + + companion object { + private const val SMALL = 1_000 + private const val MEDIUM = 10_000 + private const val LARGE = 50_000 + } + + // ---- UniqueIndex: large-scale population via incremental inserts ---- + + @Test + fun `unique index incremental insert10 k`() { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + + // Every row must be findable + for (i in 0 until MEDIUM) { + val found = index.find(i) + assertNotNull(found, "Missing row id=$i in UniqueIndex after 10K inserts") + assertEquals(i, found.id) + } + assertEquals(MEDIUM, cache.count()) + } + + @Test + fun `unique index incremental insert50 k`() { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + measureTime { + for (i in 0 until LARGE) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + } + + // Spot-check lookups across the range + for (i in listOf(0, LARGE / 4, LARGE / 2, LARGE - 1)) { + val found = index.find(i) + assertNotNull(found, "Missing row id=$i in UniqueIndex after 50K inserts") + assertEquals(i, found.id) + } + 
assertEquals(LARGE, cache.count()) + + // Measure lookup time over all rows + val lookupTime = measureTime { + for (i in 0 until LARGE) { + index.find(i) + } + } + + // Sanity: 50K lookups should complete in well under 5 seconds + assertTrue(lookupTime.inWholeMilliseconds < 5000, + "50K UniqueIndex lookups took ${lookupTime.inWholeMilliseconds}ms — too slow") + } + + // ---- UniqueIndex: construction from pre-populated cache ---- + + @Test + fun `unique index construction from pre populated cache10 k`() { + val cache = createSampleCache() + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + + // Time how long index construction takes from a full cache + val constructionTime = measureTime { + val index = UniqueIndex(cache) { it.id } + // Verify all rows indexed + assertEquals(SampleRow(0, "row-0"), index.find(0)) + assertEquals(SampleRow(MEDIUM - 1, "row-${MEDIUM - 1}"), index.find(MEDIUM - 1)) + } + + assertTrue(constructionTime.inWholeMilliseconds < 5000, + "UniqueIndex construction from 10K rows took ${constructionTime.inWholeMilliseconds}ms — too slow") + } + + @Test + fun `unique index construction from pre populated cache50 k`() { + val cache = createSampleCache() + for (i in 0 until LARGE) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + + val constructionTime = measureTime { + val index = UniqueIndex(cache) { it.id } + assertEquals(SampleRow(LARGE - 1, "row-${LARGE - 1}"), index.find(LARGE - 1)) + } + + assertTrue(constructionTime.inWholeMilliseconds < 10000, + "UniqueIndex construction from 50K rows took ${constructionTime.inWholeMilliseconds}ms — too slow") + } + + // ---- BTreeIndex: high cardinality (many unique keys) ---- + + @Test + fun `btree index high cardinality10 k`() { + val cache = createSampleCache() + // Each row has a unique name — 10K unique keys, 1 row per key + val index = BTreeIndex(cache) { it.name } + + for (i in 0 until MEDIUM) { + 
cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "unique-$i").encode())) + } + + // Every key should return exactly 1 row + for (i in 0 until MEDIUM) { + val results = index.filter("unique-$i") + assertEquals(1, results.size, "Expected 1 row for key unique-$i, got ${results.size}") + } + } + + // ---- BTreeIndex: low cardinality (few keys, many rows per key) ---- + + @Test + fun `btree index low cardinality10 k`() { + val cache = createSampleCache() + val groupCount = 10 + val index = BTreeIndex(cache) { it.name } + + for (i in 0 until MEDIUM) { + val group = "group-${i % groupCount}" + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, group).encode())) + } + + // Each group should have MEDIUM / groupCount rows + val expectedPerGroup = MEDIUM / groupCount + for (g in 0 until groupCount) { + val results = index.filter("group-$g") + assertEquals(expectedPerGroup, results.size, + "Group group-$g: expected $expectedPerGroup rows, got ${results.size}") + } + } + + @Test + fun `btree index single key with50 k rows`() { + val cache = createSampleCache() + val index = BTreeIndex(cache) { it.name } + + // All 50K rows share the same key + measureTime { + for (i in 0 until LARGE) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "shared").encode())) + } + } + + // filter() returns all 50K rows + val filterTime = measureTime { + val results = index.filter("shared") + assertEquals(LARGE, results.size) + } + + assertTrue(filterTime.inWholeMilliseconds < 2000, + "BTreeIndex filter returning 50K rows took ${filterTime.inWholeMilliseconds}ms — too slow") + + // Non-existent key returns empty + assertTrue(index.filter("nonexistent").isEmpty()) + } + + // ---- BTreeIndex: construction from pre-populated cache ---- + + @Test + fun `btree index construction from pre populated cache10 k`() { + val cache = createSampleCache() + val groupCount = 100 + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "g-${i % groupCount}").encode())) 
+ } + + val constructionTime = measureTime { + val index = BTreeIndex(cache) { it.name } + val results = index.filter("g-0") + assertEquals(MEDIUM / groupCount, results.size) + } + + assertTrue(constructionTime.inWholeMilliseconds < 5000, + "BTreeIndex construction from 10K rows took ${constructionTime.inWholeMilliseconds}ms — too slow") + } + + // ---- Bulk delete at scale ---- + + @Test + fun `unique index bulk delete50 k`() { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + // Insert 50K rows + for (i in 0 until LARGE) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + assertEquals(LARGE, cache.count()) + + // Delete all rows + val deleteTime = measureTime { + for (i in 0 until LARGE) { + val parsed = cache.parseDeletes(buildRowList(SampleRow(i, "row-$i").encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + } + + assertEquals(0, cache.count()) + // All lookups should return null + for (i in listOf(0, LARGE / 2, LARGE - 1)) { + assertEquals(null, index.find(i), "Row id=$i still in index after bulk delete") + } + + assertTrue(deleteTime.inWholeMilliseconds < 10000, + "50K row bulk delete took ${deleteTime.inWholeMilliseconds}ms — too slow") + } + + @Test + fun `btree index bulk delete converges`() { + val cache = createSampleCache() + val groupCount = 10 + val index = BTreeIndex(cache) { it.name } + val rowsPerGroup = MEDIUM / groupCount // 1000 + + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "g-${i % groupCount}").encode())) + } + + // Delete the first half of each group's rows. + // Group g has rows: g, g+10, g+20, ... — delete the first rowsPerGroup/2 of them. 
+ for (g in 0 until groupCount) { + var deleted = 0 + var id = g + while (deleted < rowsPerGroup / 2) { + val parsed = cache.parseDeletes(buildRowList(SampleRow(id, "g-$g").encode())) + cache.applyDeletes(STUB_CTX, parsed) + id += groupCount + deleted++ + } + } + + assertEquals(MEDIUM / 2, cache.count()) + // Each group should have exactly half its rows remaining + for (g in 0 until groupCount) { + val results = index.filter("g-$g") + assertEquals(rowsPerGroup / 2, results.size, + "Group g-$g after bulk delete: expected ${rowsPerGroup / 2}, got ${results.size}") + } + } + + // ---- Mixed read/write workload at scale ---- + + @Test + fun `unique index read heavy mixed workload`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + // Pre-populate with 10K rows + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "row-$i").encode())) + } + + val threadCount = 16 + val opsPerThread = 5_000 + val barrier = CyclicBarrier(threadCount) + + coroutineScope { + // 14 reader threads (87.5% reads) + repeat(threadCount - 2) { _ -> + launch { + barrier.await() + repeat(opsPerThread) { i -> + val key = i % MEDIUM + val found = index.find(key) + if (found != null) { + assertEquals(key, found.id, "Read returned wrong row") + } + } + } + } + // 2 writer threads (12.5% writes — insert new rows beyond MEDIUM) + repeat(2) { threadIdx -> + launch { + barrier.await() + val base = MEDIUM + threadIdx * opsPerThread + for (i in base until base + opsPerThread) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "new-$i").encode())) + } + } + } + } + + // All original + new rows must be in the index + val expectedTotal = MEDIUM + 2 * opsPerThread + assertEquals(expectedTotal, cache.count()) + for (i in listOf(0, MEDIUM - 1, MEDIUM, expectedTotal - 1)) { + assertNotNull(index.find(i), "Missing row id=$i after mixed workload") + } + } + + @Test + fun `btree index read heavy mixed 
workload`() = runBlocking(Dispatchers.Default) { + val cache = createSampleCache() + val groupCount = 50 + val index = BTreeIndex(cache) { it.name } + + // Pre-populate with 10K rows in 50 groups + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "g-${i % groupCount}").encode())) + } + + val threadCount = 16 + val opsPerThread = 2_000 + val barrier = CyclicBarrier(threadCount) + + coroutineScope { + // 14 reader threads + repeat(threadCount - 2) { _ -> + launch { + barrier.await() + repeat(opsPerThread) { i -> + val group = "g-${i % groupCount}" + val results = index.filter(group) + // Group should have at least the pre-populated count + assertTrue(results.isNotEmpty(), "Empty filter result for $group") + } + } + } + // 2 writer threads add rows to existing groups + repeat(2) { threadIdx -> + launch { + barrier.await() + val base = MEDIUM + threadIdx * opsPerThread + for (i in base until base + opsPerThread) { + val group = "g-${i % groupCount}" + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, group).encode())) + } + } + } + } + + val expectedTotal = MEDIUM + 2 * opsPerThread + assertEquals(expectedTotal, cache.count()) + + // Verify group counts converged + val expectedPerGroup = expectedTotal / groupCount + for (g in 0 until groupCount) { + assertEquals(expectedPerGroup, index.filter("g-$g").size, + "Group g-$g count mismatch after mixed workload") + } + } + + // ---- Insert then delete then re-insert at scale ---- + + @Test + fun `unique index insert delete reinsert cycle`() { + val cache = createSampleCache() + val index = UniqueIndex(cache) { it.id } + + // Insert 10K + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "v1-$i").encode())) + } + assertEquals(MEDIUM, cache.count()) + + // Delete all + for (i in 0 until MEDIUM) { + val parsed = cache.parseDeletes(buildRowList(SampleRow(i, "v1-$i").encode())) + cache.applyDeletes(STUB_CTX, parsed) + } + assertEquals(0, cache.count()) + 
assertEquals(null, index.find(0)) + + // Re-insert with different names + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "v2-$i").encode())) + } + assertEquals(MEDIUM, cache.count()) + + // Index should reflect the new version + for (i in listOf(0, MEDIUM / 2, MEDIUM - 1)) { + val found = index.find(i) + assertNotNull(found, "Missing row id=$i after reinsert") + assertEquals("v2-$i", found.name, "Row id=$i has stale name after reinsert") + } + } + + // ---- Multiple indexes on the same cache ---- + + @Test + fun `multiple indexes on same cache at scale`() { + val cache = createSampleCache() + val uniqueById = UniqueIndex(cache) { it.id } + val btreeByName = BTreeIndex(cache) { it.name } + + val groupCount = 20 + for (i in 0 until MEDIUM) { + cache.applyInserts(STUB_CTX, buildRowList(SampleRow(i, "g-${i % groupCount}").encode())) + } + + // UniqueIndex: every ID findable + for (i in 0 until MEDIUM step 100) { + assertNotNull(uniqueById.find(i), "UniqueIndex missing id=$i") + } + // BTreeIndex: correct group sizes + for (g in 0 until groupCount) { + assertEquals(MEDIUM / groupCount, btreeByName.filter("g-$g").size) + } + + // Delete the first half of each group's rows + val rowsPerGroup = MEDIUM / groupCount + for (g in 0 until groupCount) { + var deleted = 0 + var id = g + while (deleted < rowsPerGroup / 2) { + val parsed = cache.parseDeletes(buildRowList(SampleRow(id, "g-$g").encode())) + cache.applyDeletes(STUB_CTX, parsed) + id += groupCount + deleted++ + } + } + + assertEquals(MEDIUM / 2, cache.count()) + // Deleted rows gone from UniqueIndex (first row of g-0 = id 0) + assertEquals(null, uniqueById.find(0)) + // Second half still present (e.g. 
id = groupCount * (rowsPerGroup/2) for g-0) + val firstSurvivor = groupCount * (rowsPerGroup / 2) // first surviving row in g-0 + assertNotNull(uniqueById.find(firstSurvivor)) + // BTreeIndex groups halved + for (g in 0 until groupCount) { + assertEquals(rowsPerGroup / 2, btreeByName.filter("g-$g").size) + } + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/CompressionTest.kt b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/CompressionTest.kt new file mode 100644 index 00000000000..0dcb1ce92e8 --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/jvmTest/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/CompressionTest.kt @@ -0,0 +1,75 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol + +import java.io.ByteArrayOutputStream +import java.util.zip.GZIPOutputStream +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertSame +import kotlin.test.assertTrue + +class CompressionTest { + + /** Extract the effective payload bytes from a [DecompressedPayload]. 
*/ + private fun DecompressedPayload.toPayloadBytes(): ByteArray = + data.copyOfRange(offset, data.size) + + @Test + fun `none tag returns payload unchanged`() { + val payload = byteArrayOf(10, 20, 30, 40) + val message = byteArrayOf(Compression.NONE) + payload + + val result = decompressMessage(message) + // Zero-copy: result references the original array with offset=1 + assertSame(result.data, message, "NONE should return the original array (zero-copy)") + assertEquals(1, result.offset) + assertTrue(payload.contentEquals(result.toPayloadBytes())) + } + + @Test + fun `gzip tag decompresses payload`() { + val original = "Hello SpacetimeDB".encodeToByteArray() + + // Compress with java.util.zip + val compressed = ByteArrayOutputStream().use { baos -> + GZIPOutputStream(baos).use { gzip -> + gzip.write(original) + } + baos.toByteArray() + } + + val message = byteArrayOf(Compression.GZIP) + compressed + val result = decompressMessage(message) + assertEquals(0, result.offset) + assertTrue(original.contentEquals(result.toPayloadBytes())) + } + + @Test + fun `empty input throws`() { + assertFailsWith<IllegalArgumentException> { + decompressMessage(byteArrayOf()) + } + } + + @Test + fun `brotli tag rejects invalid data`() { + // Brotli decoder is wired up — invalid data throws IOException (not IllegalStateException) + assertFailsWith<java.io.IOException> { + decompressMessage(byteArrayOf(Compression.BROTLI, 1, 2, 3)) + } + } + + @Test + fun `unknown tag throws`() { + assertFailsWith<IllegalStateException> { + decompressMessage(byteArrayOf(0x7F, 1, 2, 3)) + } + } + + @Test + fun `none tag empty payload`() { + val message = byteArrayOf(Compression.NONE) + val result = decompressMessage(message) + assertEquals(0, result.size) + } +} diff --git a/sdks/kotlin/spacetimedb-sdk/src/nativeMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.native.kt b/sdks/kotlin/spacetimedb-sdk/src/nativeMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.native.kt new file mode 100644 index 
00000000000..444879fec3a --- /dev/null +++ b/sdks/kotlin/spacetimedb-sdk/src/nativeMain/kotlin/com/clockworklabs/spacetimedb_kotlin_sdk/shared_client/protocol/Compression.native.kt @@ -0,0 +1,20 @@ +package com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode + +internal actual val defaultCompressionMode: CompressionMode = CompressionMode.NONE + +internal actual val availableCompressionModes: Set<CompressionMode> = + setOf(CompressionMode.NONE) + +internal actual fun decompressMessage(data: ByteArray): DecompressedPayload { + require(data.isNotEmpty()) { "Empty message" } + + return when (val tag = data[0]) { + Compression.NONE -> DecompressedPayload(data, offset = 1) + // https://github.com/google/brotli/issues/1123 + Compression.BROTLI -> error("Brotli compression not supported on native.") + Compression.GZIP -> error("Gzip compression not supported on native.") + else -> error("Unknown compression tag: $tag") + } +} diff --git a/skills/spacetimedb-kotlin/SKILL.md b/skills/spacetimedb-kotlin/SKILL.md new file mode 100644 index 00000000000..c9ab6e93008 --- /dev/null +++ b/skills/spacetimedb-kotlin/SKILL.md @@ -0,0 +1,316 @@ +--- +name: spacetimedb-kotlin +description: Build Kotlin Multiplatform clients for SpacetimeDB. Covers KMP SDK integration for Android, JVM Desktop, and iOS/Native. +license: Apache-2.0 +metadata: + author: clockworklabs + version: "2.1" + tested_with: "SpacetimeDB 2.1, JDK 21+, Kotlin 2.1" +--- + +# SpacetimeDB Kotlin SDK + +Build real-time Kotlin Multiplatform clients that connect directly to SpacetimeDB modules. The SDK provides type-safe database access, automatic synchronization, and reactive updates for Android, JVM Desktop, and iOS/Native apps. + +The server module is written in Rust (or C#/TypeScript). Kotlin is a **client-only** SDK — there is no `crates/bindings-kotlin` for server-side modules. 
+ +--- + +## HALLUCINATED APIs — DO NOT USE + +**These APIs DO NOT EXIST. LLMs frequently hallucinate them.** + +```kotlin +// WRONG — these builder methods do not exist +DbConnection.Builder().withHost("localhost") // Use withUri("ws://localhost:3000") +DbConnection.Builder().withDatabase("my-db") // Use withDatabaseName("my-db") +DbConnection.Builder().withModule(Module) // Use withModuleBindings() (generated extension) +DbConnection.Builder().connect() // Use build() (suspending) + +// WRONG — blocking build +val conn = DbConnection.Builder().build() // build() is suspend — must be in coroutine + +// WRONG — table access patterns +conn.db.Person // Wrong casing — use generated accessor name +conn.tables.person // No .tables — use conn.db.person +conn.db.person.get(id) // No .get() — use index: conn.db.person.id.find(id) +conn.db.person.findById(id) // No .findById() — use conn.db.person.id.find(id) +conn.db.person.query("SELECT ...") // No SQL on client — use subscriptions + query builder + +// WRONG — callback signatures +conn.db.person.onInsert { person -> } // Missing EventContext: { ctx, person -> } +conn.db.person.onUpdate { old, new -> } // Missing EventContext: { ctx, old, new -> } +conn.db.person.onInsert(::handleInsert) // OK — function references work if signature matches (EventContext, Person) -> Unit + +// WRONG — subscription patterns +conn.subscribe("SELECT * FROM person") // No direct subscribe — use subscriptionBuilder() +conn.subscriptionBuilder().subscribe("SELECT ...") // Works, but prefer typed query builder for compile-time safety + +// WRONG — reducer call patterns +conn.call("add", "Alice") // No generic call — use conn.reducers.add("Alice") +conn.reducers.add("Alice").await() // Reducers don't return futures — use one-shot callback + +// WRONG — non-existent types +import spacetimedb.Identity // Wrong package — use com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import spacetimedb.DbConnection // Wrong — use 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection + +// WRONG — generating bindings to src/ +// The Gradle plugin generates to build/generated/spacetimedb/bindings/, NOT src/main/kotlin/module_bindings/ +``` + +### CORRECT PATTERNS + +```kotlin +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.use +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets +import module_bindings.* + +suspend fun main() { + val httpClient = HttpClient(OkHttp) { install(WebSockets) } + + DbConnection.Builder() + .withHttpClient(httpClient) + .withUri("ws://localhost:3000") + .withDatabaseName(SpacetimeConfig.DATABASE_NAME) + .withModuleBindings() // Generated extension — registers module descriptor + .onConnect { conn, identity, token -> + conn.db.person.onInsert { ctx, person -> + println("Inserted: ${person.name}") + } + + conn.subscriptionBuilder() + .subscribeToAllTables() + + conn.reducers.add("Alice") { ctx -> + println("status=${ctx.status}") + } + } + .onDisconnect { _, error -> + println("Disconnected: ${error?.message ?: "clean"}") + } + .build() + .use { delay(Duration.INFINITE) } +} +``` + +--- + +## Common Mistakes Table + +| Wrong | Right | Why | +|-------|-------|-----| +| `DbConnection.Builder().build()` outside coroutine | Wrap in `runBlocking` or `launch` | `build()` is `suspend` | +| Forgetting `install(WebSockets)` on HttpClient | `HttpClient(OkHttp) { install(WebSockets) }` | SDK needs WebSocket support | +| Using `withModuleDescriptor(Module)` | Use `withModuleBindings()` | Generated extension handles registration | +| Callbacks without EventContext | `{ ctx, row -> }` not `{ row -> }` | All callbacks receive EventContext first | +| `onUpdate` on table without primary key | Only available on `RemotePersistentTableWithPrimaryKey` | Need `#[primary_key]` on server table | +| Calling 
`conn.db` from wrong thread | SDK is coroutine-safe via atomic state | Use from any coroutine scope | +| Generating bindings to `src/` | Gradle plugin generates to `build/generated/spacetimedb/bindings/` | Bindings are build artifacts, not source | +| Using `includeBuild` without local SDK checkout | Required until SDK is published on Maven Central | Templates have placeholder comments | + +--- + +## Hard Requirements + +1. **JDK 21+** — required by the SDK and Gradle plugin +2. **Ktor HttpClient with WebSockets** — must `install(WebSockets)` on the client +3. **`build()` is suspending** — must be called from a coroutine +4. **`withModuleBindings()`** — generated extension, call on builder to register module +5. **`SpacetimeConfig.DATABASE_NAME`** — generated constant, use for database name +6. **Callbacks always receive `EventContext` as first param** — `{ ctx, row -> }` +7. **`onUpdate` requires primary key** — only on `RemotePersistentTableWithPrimaryKey` +8. **Gradle plugin auto-generates bindings** — no manual `spacetime generate` needed when using the plugin +9. 
**Server module is Rust** — templates use Rust server modules, not Kotlin + +--- + +## Client SDK API + +### DbConnection.Builder + +```kotlin +val conn = DbConnection.Builder() + .withHttpClient(httpClient) // Required: Ktor HttpClient + .withUri("ws://localhost:3000") // Required: WebSocket URL + .withDatabaseName("my-database") // Required: database name + .withToken(savedToken) // Optional: auth token for reconnection + .withModuleBindings() // Required: generated extension + .onConnect { conn, identity, token -> } // Connected callback + .onDisconnect { conn, error -> } // Disconnected callback + .onConnectError { conn, error -> } // Connection failed callback + .build() // Suspending — returns DbConnection +``` + +### Connection Lifecycle + +```kotlin +// Keep alive with automatic cleanup +conn.use { + delay(Duration.INFINITE) +} + +// Manual disconnect +conn.disconnect() +``` + +### Table Access (Client Cache) + +```kotlin +// Read cached rows +conn.db.person.count() // Int +conn.db.person.all() // List<Person> +conn.db.person.iter() // Sequence<Person> + +// Index lookups (generated per-table) +conn.db.person.id.find(42u) // Person? 
— unique index +conn.db.person.nameIdx.filter("Alice") // Set<Person> — BTree index +``` + +### Row Callbacks + +```kotlin +conn.db.person.onInsert { ctx, person -> } +conn.db.person.onDelete { ctx, person -> } +conn.db.person.onUpdate { ctx, oldPerson, newPerson -> } // PK tables only +conn.db.person.onBeforeDelete { ctx, person -> } + +// Remove callback +val cb: (EventContext, Person) -> Unit = { ctx, p -> println(p) } +conn.db.person.onInsert(cb) +conn.db.person.removeOnInsert(cb) +``` + +### Reducers + +```kotlin +// Call a reducer +conn.reducers.add("Alice") + +// Call with one-shot callback +conn.reducers.add("Alice") { ctx -> + println("status=${ctx.status}") +} + +// Observe all calls to a reducer +conn.reducers.onAdd { ctx, name -> + println("add($name) status=${ctx.status}") +} +``` + +### Subscriptions + +```kotlin +// Subscribe to all tables +conn.subscriptionBuilder() + .onError { _, error -> println(error) } + .subscribeToAllTables() + +// Type-safe query builder +conn.subscriptionBuilder() + .addQuery { qb -> qb.person().where { cols -> cols.name.eq("Alice") } } + .onApplied { println("Applied") } + .subscribe() + +// Query builder operations +qb.person() + .where { cols -> cols.name.eq("Alice").and(cols.id.gt(0u)) } + +qb.person() + .leftSemijoin(qb.team()) { person, team -> + person.teamId.eq(team.id) + } +``` + +### Identity + +```kotlin +identity.toHexString() // Hex string representation +``` + +--- + +## Type Mappings + +| SpacetimeDB | Kotlin | +|-------------|--------| +| `bool` | `Boolean` | +| `u8`/`u16`/`u32`/`u64` | `UByte`/`UShort`/`UInt`/`ULong` | +| `i8`/`i16`/`i32`/`i64` | `Byte`/`Short`/`Int`/`Long` | +| `u128`/`u256` | `UInt128`/`UInt256` | +| `i128`/`i256` | `Int128`/`Int256` | +| `f32`/`f64` | `Float`/`Double` | +| `String` | `String` | +| `Vec<u8>` | `ByteArray` | +| `Vec<T>` | `List<T>` | +| `Option<T>` | `T?` | +| `Identity` | `Identity` | +| `ConnectionId` | `ConnectionId` | +| `Timestamp` | `Timestamp` | +| `TimeDuration` | `TimeDuration` | +| 
`ScheduleAt` | `ScheduleAt` | +| `Uuid` | `SpacetimeUuid` | +| Product types | `data class` | +| Sum types (all unit) | `enum class` | +| Sum types (mixed) | `sealed interface` | + +--- + +## Project Structure + +### basic-kt (JVM-only) + +``` +my-app/ +├── spacetimedb/ # Rust server module +│ ├── Cargo.toml +│ └── src/lib.rs +├── src/main/kotlin/ +│ └── Main.kt # JVM client +├── build/generated/spacetimedb/ +│ └── bindings/ # Auto-generated (by Gradle plugin) +├── build.gradle.kts +├── settings.gradle.kts +└── spacetime.json +``` + +### compose-kt (KMP: Android + Desktop) + +``` +my-app/ +├── spacetimedb/ # Rust server module +├── androidApp/ # Android entry point (MainActivity) +├── desktopApp/ # Desktop entry point (main.kt) +├── sharedClient/ # Shared KMP module (UI + SpacetimeDB client) +│ └── src/ +│ ├── commonMain/kotlin/app/ +│ │ ├── AppViewModel.kt +│ │ ├── ChatRepository.kt +│ │ └── composable/ # Compose UI screens +│ ├── androidMain/ # Android-specific (TokenStore) +│ └── jvmMain/ # Desktop-specific (TokenStore) +└── spacetime.json +``` + +--- + +## Commands + +```bash +# Create project from template +spacetime init --template basic-kt --project-path ./my-app --non-interactive my-app + +# Build and run (interactive — requires terminal) +spacetime dev + +# Generate bindings manually (not needed with Gradle plugin) +spacetime generate --lang kotlin --out-dir src/main/kotlin/module_bindings --module-path spacetimedb + +# Build Kotlin client +./gradlew compileKotlin + +# Run Kotlin client +./gradlew run +``` diff --git a/templates/basic-kt/.gitignore b/templates/basic-kt/.gitignore new file mode 100644 index 00000000000..0838237ae30 --- /dev/null +++ b/templates/basic-kt/.gitignore @@ -0,0 +1,26 @@ +# Gradle +.gradle/ +**/build/ + +# Kotlin +.kotlin/ + +# IDE +.idea/ +*.iml +.vscode/ + +# SpacetimeDB server module build artifacts +target/ + +# OS +.DS_Store + +# Logs +*.log + +# Local configuration +local.properties +spacetime.local.json +.env +.env.local 
diff --git a/templates/basic-kt/.template.json b/templates/basic-kt/.template.json new file mode 100644 index 00000000000..d006aea6b9d --- /dev/null +++ b/templates/basic-kt/.template.json @@ -0,0 +1,5 @@ +{ + "description": "A basic Kotlin client and Rust server template with only stubs for code", + "client_lang": "kotlin", + "server_lang": "rust" +} diff --git a/templates/basic-kt/build.gradle.kts b/templates/basic-kt/build.gradle.kts new file mode 100644 index 00000000000..ba832f502be --- /dev/null +++ b/templates/basic-kt/build.gradle.kts @@ -0,0 +1,20 @@ +plugins { + alias(libs.plugins.kotlinJvm) + alias(libs.plugins.spacetimedb) + application +} + +kotlin { + jvmToolchain(21) +} + +application { + mainClass.set("MainKt") +} + +dependencies { + implementation(libs.spacetimedb.sdk) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.ktor.client.okhttp) + implementation(libs.ktor.client.websockets) +} diff --git a/templates/basic-kt/gradle.properties b/templates/basic-kt/gradle.properties new file mode 100644 index 00000000000..822eb53a453 --- /dev/null +++ b/templates/basic-kt/gradle.properties @@ -0,0 +1,8 @@ +#Kotlin +kotlin.code.style=official +kotlin.daemon.jvmargs=-Xmx3072M + +#Gradle +org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 +org.gradle.configuration-cache=true +org.gradle.caching=true diff --git a/templates/basic-kt/gradle/gradle-daemon-jvm.properties b/templates/basic-kt/gradle/gradle-daemon-jvm.properties new file mode 100644 index 00000000000..6c1139ec06a --- /dev/null +++ b/templates/basic-kt/gradle/gradle-daemon-jvm.properties @@ -0,0 +1,12 @@ +#This file is generated by updateDaemonJvm +toolchainUrl.FREE_BSD.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.FREE_BSD.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.LINUX.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect 
+toolchainUrl.LINUX.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.MAC_OS.AARCH64=https\://api.foojay.io/disco/v3.0/ids/73bcfb608d1fde9fb62e462f834a3299/redirect +toolchainUrl.MAC_OS.X86_64=https\://api.foojay.io/disco/v3.0/ids/846ee0d876d26a26f37aa1ce8de73224/redirect +toolchainUrl.UNIX.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.UNIX.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.WINDOWS.AARCH64=https\://api.foojay.io/disco/v3.0/ids/9482ddec596298c84656d31d16652665/redirect +toolchainUrl.WINDOWS.X86_64=https\://api.foojay.io/disco/v3.0/ids/39701d92e1756bb2f141eb67cd4c660e/redirect +toolchainVersion=21 diff --git a/templates/basic-kt/gradle/libs.versions.toml b/templates/basic-kt/gradle/libs.versions.toml new file mode 100644 index 00000000000..045249eee80 --- /dev/null +++ b/templates/basic-kt/gradle/libs.versions.toml @@ -0,0 +1,15 @@ +[versions] +kotlin = "2.3.10" +kotlinx-coroutines = "1.10.2" +ktor = "3.4.1" +spacetimedb-sdk = "0.1.0" + +[libraries] +spacetimedb-sdk = { module = "com.clockworklabs:spacetimedb-sdk", version.ref = "spacetimedb-sdk" } +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } +ktor-client-okhttp = { module = "io.ktor:ktor-client-okhttp", version.ref = "ktor" } +ktor-client-websockets = { module = "io.ktor:ktor-client-websockets", version.ref = "ktor" } + +[plugins] +kotlinJvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } +spacetimedb = { id = "com.clockworklabs.spacetimedb", version.ref = "spacetimedb-sdk" } diff --git a/templates/basic-kt/gradle/wrapper/gradle-wrapper.jar b/templates/basic-kt/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000000..2c3521197d7 Binary files /dev/null and b/templates/basic-kt/gradle/wrapper/gradle-wrapper.jar differ diff --git 
a/templates/basic-kt/gradle/wrapper/gradle-wrapper.properties b/templates/basic-kt/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000000..37f78a6af83 --- /dev/null +++ b/templates/basic-kt/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.3.1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/templates/basic-kt/gradlew b/templates/basic-kt/gradlew new file mode 100755 index 00000000000..f5feea6d6b1 --- /dev/null +++ b/templates/basic-kt/gradlew @@ -0,0 +1,252 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. 
If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. 
+while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s +' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. 
That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/templates/basic-kt/gradlew.bat b/templates/basic-kt/gradlew.bat new file mode 100644 index 00000000000..9b42019c791 --- /dev/null +++ b/templates/basic-kt/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. 
+@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 
1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/templates/basic-kt/settings.gradle.kts b/templates/basic-kt/settings.gradle.kts new file mode 100644 index 00000000000..9563fce5d47 --- /dev/null +++ b/templates/basic-kt/settings.gradle.kts @@ -0,0 +1,25 @@ +@file:Suppress("UnstableApiUsage") + +rootProject.name = "basic-kt" + +pluginManagement { + repositories { + mavenCentral() + gradlePluginPortal() + } + // TODO: Replace with published Maven coordinates once the SDK is available on Maven Central. + // includeBuild("/spacetimedb-gradle-plugin") +} + +dependencyResolutionManagement { + repositories { + mavenCentral() + } +} + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "1.0.0" +} + +// TODO: Replace with published Maven coordinates once the SDK is available on Maven Central. 
+// includeBuild("") diff --git a/templates/basic-kt/spacetime.json b/templates/basic-kt/spacetime.json new file mode 100644 index 00000000000..e9641b9265a --- /dev/null +++ b/templates/basic-kt/spacetime.json @@ -0,0 +1,4 @@ +{ + "server": "local", + "module-path": "./spacetimedb" +} diff --git a/templates/basic-kt/spacetimedb/Cargo.toml b/templates/basic-kt/spacetimedb/Cargo.toml new file mode 100644 index 00000000000..448f0e7e1b0 --- /dev/null +++ b/templates/basic-kt/spacetimedb/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "basic-kt" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +spacetimedb = "2.0" +log = "0.4" diff --git a/templates/basic-kt/spacetimedb/src/lib.rs b/templates/basic-kt/spacetimedb/src/lib.rs new file mode 100644 index 00000000000..c415537c423 --- /dev/null +++ b/templates/basic-kt/spacetimedb/src/lib.rs @@ -0,0 +1,37 @@ +use spacetimedb::{ReducerContext, Table}; + +#[spacetimedb::table(accessor = person, public)] +pub struct Person { + #[primary_key] + #[auto_inc] + id: u64, + name: String, +} + +#[spacetimedb::reducer(init)] +pub fn init(_ctx: &ReducerContext) { + // Called when the module is initially published +} + +#[spacetimedb::reducer(client_connected)] +pub fn identity_connected(_ctx: &ReducerContext) { + // Called every time a new client connects +} + +#[spacetimedb::reducer(client_disconnected)] +pub fn identity_disconnected(_ctx: &ReducerContext) { + // Called every time a client disconnects +} + +#[spacetimedb::reducer] +pub fn add(ctx: &ReducerContext, name: String) { + ctx.db.person().insert(Person { id: 0, name }); +} + +#[spacetimedb::reducer] +pub fn say_hello(ctx: &ReducerContext) { + for person in ctx.db.person().iter() { + log::info!("Hello, {}!", person.name); + } + log::info!("Hello, World!"); +} diff --git a/templates/basic-kt/src/main/kotlin/Main.kt b/templates/basic-kt/src/main/kotlin/Main.kt new file mode 100644 index 00000000000..7fb52027053 --- /dev/null +++ 
b/templates/basic-kt/src/main/kotlin/Main.kt @@ -0,0 +1,55 @@ +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.use +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets +import kotlinx.coroutines.delay +import module_bindings.db +import module_bindings.reducers +import module_bindings.subscribeToAllTables +import module_bindings.withModuleBindings +import kotlin.time.Duration.Companion.seconds + +suspend fun main() { + val host = System.getenv("SPACETIMEDB_HOST") ?: "ws://localhost:3000" + val httpClient = HttpClient(OkHttp) { install(WebSockets) } + + DbConnection.Builder() + .withHttpClient(httpClient) + .withUri(host) + .withDatabaseName(module_bindings.SpacetimeConfig.DATABASE_NAME) + .withModuleBindings() + .onConnect { conn, identity, _ -> + println("Connected to SpacetimeDB!") + println("Identity: ${identity.toHexString().take(16)}...") + + conn.db.person.onInsert { _, person -> + println("New person: ${person.name}") + } + + conn.reducers.onAdd { ctx, name -> + println("[onAdd] Added person: $name (status=${ctx.status})") + } + + conn.subscriptionBuilder() + .onError { _, error -> println("Subscription error: $error") } + .subscribeToAllTables() + + conn.reducers.add("Alice") { ctx -> + println("[one-shot] Add completed: status=${ctx.status}") + conn.reducers.sayHello() + } + } + .onDisconnect { _, error -> + if (error != null) { + println("Disconnected with error: $error") + } else { + println("Disconnected") + } + } + .onConnectError { _, error -> + println("Connection error: $error") + } + .build() + .use { delay(5.seconds) } +} diff --git a/templates/compose-kt/.gitignore b/templates/compose-kt/.gitignore new file mode 100644 index 00000000000..0838237ae30 --- /dev/null +++ b/templates/compose-kt/.gitignore @@ -0,0 +1,26 @@ +# Gradle +.gradle/ +**/build/ + +# Kotlin +.kotlin/ + +# IDE 
+.idea/ +*.iml +.vscode/ + +# SpacetimeDB server module build artifacts +target/ + +# OS +.DS_Store + +# Logs +*.log + +# Local configuration +local.properties +spacetime.local.json +.env +.env.local diff --git a/templates/compose-kt/.template.json b/templates/compose-kt/.template.json new file mode 100644 index 00000000000..fb12e14f155 --- /dev/null +++ b/templates/compose-kt/.template.json @@ -0,0 +1,5 @@ +{ + "description": "A Compose Multiplatform (Android + Desktop) chat client with a Rust server", + "client_lang": "kotlin", + "server_lang": "rust" +} diff --git a/templates/compose-kt/androidApp/build.gradle.kts b/templates/compose-kt/androidApp/build.gradle.kts new file mode 100644 index 00000000000..50c31976e2a --- /dev/null +++ b/templates/compose-kt/androidApp/build.gradle.kts @@ -0,0 +1,31 @@ +plugins { + alias(libs.plugins.androidApplication) + alias(libs.plugins.composeCompiler) +} + +android { + namespace = "com.clockworklabs.spacetimedb_compose_kt" + compileSdk { + version = release(libs.versions.android.compileSdk.get().toInt()) + } + + defaultConfig { + applicationId = "com.clockworklabs.spacetimedb_compose_kt" + minSdk = libs.versions.android.minSdk.get().toInt() + targetSdk = libs.versions.android.targetSdk.get().toInt() + versionCode = 1 + versionName = "1.0" + } + packaging { + resources { + excludes += "/META-INF/{AL2.0,LGPL2.1}" + } + } +} + +dependencies { + implementation(projects.sharedClient) + implementation(libs.androidx.activity.compose) + implementation(libs.ktor.client.okhttp) + implementation(libs.ktor.client.websockets) +} diff --git a/templates/compose-kt/androidApp/src/main/AndroidManifest.xml b/templates/compose-kt/androidApp/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..aa28973654a --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/AndroidManifest.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + diff --git a/templates/compose-kt/androidApp/src/main/kotlin/MainActivity.kt 
b/templates/compose-kt/androidApp/src/main/kotlin/MainActivity.kt new file mode 100644 index 00000000000..8dd05af4501 --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/kotlin/MainActivity.kt @@ -0,0 +1,38 @@ +package com.clockworklabs.spacetimedb_compose_kt + +import android.os.Bundle +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.activity.enableEdgeToEdge +import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.ViewModelProvider.AndroidViewModelFactory.Companion.APPLICATION_KEY +import androidx.lifecycle.viewmodel.initializer +import androidx.lifecycle.viewmodel.viewModelFactory +import app.AppViewModel +import app.ChatRepository +import app.TokenStore +import app.composable.App +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets + +class MainActivity : ComponentActivity() { + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + enableEdgeToEdge() + + val factory = viewModelFactory { + initializer { + val context = this[APPLICATION_KEY]!! + val httpClient = HttpClient(OkHttp) { install(WebSockets) } + val tokenStore = TokenStore(context) + val repository = ChatRepository(httpClient, tokenStore) + // 10.0.2.2 is the Android emulator's alias for the host machine's loopback. + // For physical devices, replace with your machine's LAN IP (e.g. "ws://192.168.1.x:3000"). 
+ AppViewModel(repository, defaultHost = "ws://10.0.2.2:3000") + } + } + val viewModel = ViewModelProvider(this, factory)[AppViewModel::class.java] + setContent { App(viewModel) } + } +} diff --git a/templates/compose-kt/androidApp/src/main/res/drawable-v24/ic_launcher_foreground.xml b/templates/compose-kt/androidApp/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 00000000000..2b068d11462 --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/templates/compose-kt/androidApp/src/main/res/drawable/ic_launcher_background.xml b/templates/compose-kt/androidApp/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 00000000000..e93e11adef9 --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/templates/compose-kt/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/templates/compose-kt/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 00000000000..eca70cfe52e --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/templates/compose-kt/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/templates/compose-kt/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 00000000000..eca70cfe52e --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/templates/compose-kt/androidApp/src/main/res/values/strings.xml b/templates/compose-kt/androidApp/src/main/res/values/strings.xml new file mode 100644 index 
00000000000..46550d479e9 --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + compose-kt + \ No newline at end of file diff --git a/templates/compose-kt/androidApp/src/main/res/xml/network_security_config.xml b/templates/compose-kt/androidApp/src/main/res/xml/network_security_config.xml new file mode 100644 index 00000000000..0cc4b5ea367 --- /dev/null +++ b/templates/compose-kt/androidApp/src/main/res/xml/network_security_config.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/templates/compose-kt/build.gradle.kts b/templates/compose-kt/build.gradle.kts new file mode 100644 index 00000000000..984245106fe --- /dev/null +++ b/templates/compose-kt/build.gradle.kts @@ -0,0 +1,24 @@ +plugins { + alias(libs.plugins.androidApplication) apply false + alias(libs.plugins.androidKotlinMultiplatformLibrary) apply false + alias(libs.plugins.kotlinJvm) apply false + alias(libs.plugins.kotlinMultiplatform) apply false + alias(libs.plugins.composeMultiplatform) apply false + alias(libs.plugins.composeCompiler) apply false + alias(libs.plugins.spacetimedb) apply false +} + +subprojects { + afterEvaluate { + plugins.withId("org.jetbrains.kotlin.multiplatform") { + extensions.configure { + jvmToolchain(21) + } + } + plugins.withId("org.jetbrains.kotlin.jvm") { + extensions.configure { + jvmToolchain(21) + } + } + } +} diff --git a/templates/compose-kt/desktopApp/build.gradle.kts b/templates/compose-kt/desktopApp/build.gradle.kts new file mode 100644 index 00000000000..1e20d29d222 --- /dev/null +++ b/templates/compose-kt/desktopApp/build.gradle.kts @@ -0,0 +1,28 @@ +import org.jetbrains.compose.desktop.application.dsl.TargetFormat + +plugins { + alias(libs.plugins.kotlinJvm) + alias(libs.plugins.composeMultiplatform) + alias(libs.plugins.composeCompiler) +} + +dependencies { + implementation(projects.sharedClient) + implementation(compose.desktop.currentOs) + implementation(libs.androidx.lifecycle.viewmodel) + 
implementation(libs.kotlinx.coroutines.swing) + implementation(libs.ktor.client.okhttp) + implementation(libs.ktor.client.websockets) +} + +compose.desktop { + application { + mainClass = "MainKt" + + nativeDistributions { + targetFormats(TargetFormat.Dmg, TargetFormat.Msi, TargetFormat.Deb) + packageName = "com.clockworklabs.spacetimedb_compose_kt" + packageVersion = "1.0.0" + } + } +} diff --git a/templates/compose-kt/desktopApp/src/main/kotlin/main.kt b/templates/compose-kt/desktopApp/src/main/kotlin/main.kt new file mode 100644 index 00000000000..fec3b135fd0 --- /dev/null +++ b/templates/compose-kt/desktopApp/src/main/kotlin/main.kt @@ -0,0 +1,27 @@ +import androidx.compose.runtime.remember +import androidx.compose.ui.window.Window +import androidx.compose.ui.window.application +import app.AppViewModel +import app.ChatRepository +import app.TokenStore +import app.composable.App +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets +fun main() = application { + val httpClient = remember { HttpClient(OkHttp) { install(WebSockets) } } + val tokenStore = remember { TokenStore() } + val repository = remember { ChatRepository(httpClient, tokenStore) } + val viewModel = remember { AppViewModel(repository, defaultHost = "ws://localhost:3000") } + Window( + onCloseRequest = { + // ViewModel.onCleared handles disconnect via runBlocking. + // Just close the HTTP client and exit. 
+ httpClient.close() + exitApplication() + }, + title = "SpacetimeDB Chat", + ) { + App(viewModel) + } +} diff --git a/templates/compose-kt/gradle.properties b/templates/compose-kt/gradle.properties new file mode 100644 index 00000000000..9281d52d140 --- /dev/null +++ b/templates/compose-kt/gradle.properties @@ -0,0 +1,9 @@ +#Kotlin +kotlin.code.style=official +kotlin.daemon.jvmargs=-Xmx3072M +kotlin.native.ignoreDisabledTargets=true + +#Gradle +org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 +org.gradle.configuration-cache=true +org.gradle.caching=true diff --git a/templates/compose-kt/gradle/gradle-daemon-jvm.properties b/templates/compose-kt/gradle/gradle-daemon-jvm.properties new file mode 100644 index 00000000000..6c1139ec06a --- /dev/null +++ b/templates/compose-kt/gradle/gradle-daemon-jvm.properties @@ -0,0 +1,12 @@ +#This file is generated by updateDaemonJvm +toolchainUrl.FREE_BSD.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.FREE_BSD.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.LINUX.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.LINUX.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.MAC_OS.AARCH64=https\://api.foojay.io/disco/v3.0/ids/73bcfb608d1fde9fb62e462f834a3299/redirect +toolchainUrl.MAC_OS.X86_64=https\://api.foojay.io/disco/v3.0/ids/846ee0d876d26a26f37aa1ce8de73224/redirect +toolchainUrl.UNIX.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.UNIX.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.WINDOWS.AARCH64=https\://api.foojay.io/disco/v3.0/ids/9482ddec596298c84656d31d16652665/redirect +toolchainUrl.WINDOWS.X86_64=https\://api.foojay.io/disco/v3.0/ids/39701d92e1756bb2f141eb67cd4c660e/redirect +toolchainVersion=21 diff --git 
a/templates/compose-kt/gradle/libs.versions.toml b/templates/compose-kt/gradle/libs.versions.toml new file mode 100644 index 00000000000..8a7d1d628cd --- /dev/null +++ b/templates/compose-kt/gradle/libs.versions.toml @@ -0,0 +1,39 @@ +[versions] +agp = "9.1.0" +android-compileSdk = "36" +android-minSdk = "26" +android-targetSdk = "36" +androidx-activityCompose = "1.12.4" +androidx-lifecycle = "2.9.6" +compose-multiplatform = "1.10.2" +kotlin = "2.3.10" +kotlinx-coroutines = "1.10.2" +kotlinxCollectionsImmutable = "0.4.0" +dateTime = "0.7.1" +ktor = "3.4.1" +spacetimedb-sdk = "0.1.0" +material3 = "1.9.0" + +[libraries] +androidx-activity-compose = { module = "androidx.activity:activity-compose", version.ref = "androidx-activityCompose" } +androidx-lifecycle-runtime-compose = { group = "org.jetbrains.androidx.lifecycle", name = "lifecycle-runtime-compose", version.ref = "androidx-lifecycle" } +androidx-lifecycle-viewmodel = { group = "org.jetbrains.androidx.lifecycle", name = "lifecycle-viewmodel", version.ref = "androidx-lifecycle" } +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } +kotlinx-coroutines-swing = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-swing", version.ref = "kotlinx-coroutines" } +kotlinx-collections-immutable = { module = "org.jetbrains.kotlinx:kotlinx-collections-immutable", version.ref = "kotlinxCollectionsImmutable" } +kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "dateTime" } +ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } +ktor-client-okhttp = { module = "io.ktor:ktor-client-okhttp", version.ref = "ktor" } +ktor-client-websockets = { module = "io.ktor:ktor-client-websockets", version.ref = "ktor" } +clockworklabs-spacetimedb-sdk = { module = "com.clockworklabs:spacetimedb-sdk", version.ref = "spacetimedb-sdk" } +material3 = { module = "org.jetbrains.compose.material3:material3", 
version.ref = "material3" } + +[plugins] +androidApplication = { id = "com.android.application", version.ref = "agp" } +androidKotlinMultiplatformLibrary = { id = "com.android.kotlin.multiplatform.library", version.ref = "agp" } +kotlinJvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } +kotlinMultiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" } +composeMultiplatform = { id = "org.jetbrains.compose", version.ref = "compose-multiplatform" } +composeCompiler = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } +# TODO: Add version.ref = "spacetimedb-sdk" after publishing to Maven Central +spacetimedb = { id = "com.clockworklabs.spacetimedb" } diff --git a/templates/compose-kt/gradle/wrapper/gradle-wrapper.jar b/templates/compose-kt/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000000..2c3521197d7 Binary files /dev/null and b/templates/compose-kt/gradle/wrapper/gradle-wrapper.jar differ diff --git a/templates/compose-kt/gradle/wrapper/gradle-wrapper.properties b/templates/compose-kt/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000000..37f78a6af83 --- /dev/null +++ b/templates/compose-kt/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-9.3.1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/templates/compose-kt/gradlew b/templates/compose-kt/gradlew new file mode 100755 index 00000000000..f5feea6d6b1 --- /dev/null +++ b/templates/compose-kt/gradlew @@ -0,0 +1,252 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. 
+# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s +' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! 
-x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. 
+ +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. 
+# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/templates/compose-kt/gradlew.bat b/templates/compose-kt/gradlew.bat new file mode 100644 index 00000000000..9b42019c791 --- /dev/null +++ b/templates/compose-kt/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. 
+@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 
1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/templates/compose-kt/settings.gradle.kts b/templates/compose-kt/settings.gradle.kts new file mode 100644 index 00000000000..61138d6c595 --- /dev/null +++ b/templates/compose-kt/settings.gradle.kts @@ -0,0 +1,44 @@ +@file:Suppress("UnstableApiUsage") + +rootProject.name = "compose-kt" +enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") + +pluginManagement { + repositories { + google { + mavenContent { + includeGroupAndSubgroups("androidx") + includeGroupAndSubgroups("com.android") + includeGroupAndSubgroups("com.google") + } + } + mavenCentral() + gradlePluginPortal() + } + // TODO: Replace with published Maven coordinates once the SDK is available on Maven Central. + // includeBuild("/spacetimedb-gradle-plugin") +} + +dependencyResolutionManagement { + repositories { + google { + mavenContent { + includeGroupAndSubgroups("androidx") + includeGroupAndSubgroups("com.android") + includeGroupAndSubgroups("com.google") + } + } + mavenCentral() + } +} + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "1.0.0" +} + +// TODO: Replace with published Maven coordinates once the SDK is available on Maven Central. 
+// includeBuild("") + +include(":desktopApp") +include(":androidApp") +include(":sharedClient") \ No newline at end of file diff --git a/templates/compose-kt/sharedClient/build.gradle.kts b/templates/compose-kt/sharedClient/build.gradle.kts new file mode 100644 index 00000000000..b5715291693 --- /dev/null +++ b/templates/compose-kt/sharedClient/build.gradle.kts @@ -0,0 +1,32 @@ +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.androidKotlinMultiplatformLibrary) + alias(libs.plugins.composeMultiplatform) + alias(libs.plugins.composeCompiler) + alias(libs.plugins.spacetimedb) +} + +kotlin { + compilerOptions.freeCompilerArgs.add("-Xexpect-actual-classes") + + android { + compileSdk = libs.versions.android.compileSdk.get().toInt() + minSdk = libs.versions.android.minSdk.get().toInt() + namespace = "com.clockworklabs.spacetimedb_compose_kt.shared_client" + } + + jvm() + + sourceSets { + commonMain.dependencies { + implementation(libs.androidx.lifecycle.runtime.compose) + implementation(libs.androidx.lifecycle.viewmodel) + implementation(libs.clockworklabs.spacetimedb.sdk) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.kotlinx.collections.immutable) + implementation(libs.kotlinx.datetime) + implementation(libs.ktor.client.core) + implementation(libs.material3) + } + } +} diff --git a/templates/compose-kt/sharedClient/src/androidMain/kotlin/app/TokenStore.android.kt b/templates/compose-kt/sharedClient/src/androidMain/kotlin/app/TokenStore.android.kt new file mode 100644 index 00000000000..14f21e11ff7 --- /dev/null +++ b/templates/compose-kt/sharedClient/src/androidMain/kotlin/app/TokenStore.android.kt @@ -0,0 +1,26 @@ +package app + +import android.content.Context +import java.io.File + +actual class TokenStore(private val context: Context) { + private val tokenDir: File + get() = File(context.filesDir, "spacetimedb/tokens") + + private fun tokenFile(clientId: String): File { + require(clientId.isNotEmpty() && clientId.all { 
it.isLetterOrDigit() || it == '-' || it == '_' }) { + "Invalid clientId: must be non-empty and contain only alphanumeric, '-', or '_' characters" + } + return File(tokenDir, clientId) + } + + actual fun load(clientId: String): String? { + val file = tokenFile(clientId) + return if (file.exists()) file.readText().trim() else null + } + + actual fun save(clientId: String, token: String) { + tokenDir.mkdirs() + tokenFile(clientId).writeText(token) + } +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppAction.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppAction.kt new file mode 100644 index 00000000000..53b615b166e --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppAction.kt @@ -0,0 +1,15 @@ +package app + +sealed interface AppAction { + sealed interface Login : AppAction { + data class OnClientChanged(val client: String) : Login + data class OnHostChanged(val host: String) : Login + data object OnSubmitClicked : Login + } + + sealed interface Chat : AppAction { + data class UpdateInput(val input: String) : Chat + data object Submit : Chat + data object Logout : Chat + } +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppState.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppState.kt new file mode 100644 index 00000000000..aea5b87539e --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppState.kt @@ -0,0 +1,60 @@ +package app + +import androidx.compose.runtime.Immutable +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.collections.immutable.ImmutableList +import kotlinx.collections.immutable.persistentListOf + +@Immutable +data class FieldInput(val value: String = "", val error: String? 
= null) + +@Immutable +data class AppState( + val login: Login = Login(), + val chat: Chat = Chat(), + val currentScreen: Screen = Screen.LOGIN, +) { + enum class Screen { + LOGIN, CHAT + } + + @Immutable + data class Login( + val clientIdField: FieldInput = FieldInput(), + val hostField: FieldInput = FieldInput(), + ) + + @Immutable + data class Chat( + val lines: ImmutableList = persistentListOf(), + val input: String = "", + val connected: Boolean = false, + val onlineUsers: ImmutableList = persistentListOf(), + val offlineUsers: ImmutableList = persistentListOf(), + val notes: ImmutableList = persistentListOf(), + val noteSubState: String = "none", + val dbName: String = "", + ) { + + @Immutable + sealed interface ChatLine { + @Immutable + data class Msg( + val id: ULong, + val sender: String, + val text: String, + val sent: Timestamp, + ) : ChatLine + + @Immutable + data class System(val text: String) : ChatLine + } + + @Immutable + data class NoteUi( + val id: ULong, + val tag: String, + val content: String, + ) + } +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppViewModel.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppViewModel.kt new file mode 100644 index 00000000000..e486cf544b0 --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/AppViewModel.kt @@ -0,0 +1,265 @@ +package app + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import kotlinx.collections.immutable.toImmutableList +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.WhileSubscribed +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.stateIn +import 
kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.datetime.TimeZone +import kotlinx.datetime.number +import kotlinx.datetime.toLocalDateTime +import kotlin.time.Duration.Companion.seconds + +class AppViewModel( + private val chatRepository: ChatRepository, + defaultHost: String, +) : ViewModel() { + + private var observationJob: Job? = null + + private val _state = MutableStateFlow(AppState(login = AppState.Login(hostField = FieldInput(defaultHost)))) + val state: StateFlow = _state + .stateIn( + scope = viewModelScope, + started = SharingStarted.WhileSubscribed(5.seconds), + initialValue = _state.value + ) + + fun onAction(action: AppAction) { + when (action) { + is AppAction.Login.OnClientChanged -> updateLogin { + copy(clientIdField = clientIdField.copy(value = action.client, error = null)) + } + + is AppAction.Login.OnHostChanged -> updateLogin { + copy(hostField = hostField.copy(value = action.host, error = null)) + } + + AppAction.Login.OnSubmitClicked -> handleLoginSubmit() + + is AppAction.Chat.UpdateInput -> updateChat { + copy(input = action.input) + } + + AppAction.Chat.Submit -> handleChatSubmit() + AppAction.Chat.Logout -> handleLogout() + } + } + + private fun handleLoginSubmit() { + val currentState = _state.value + val clientId = currentState.login.clientIdField.value + val host = currentState.login.hostField.value + + if (clientId.isBlank()) { + updateLogin { copy(clientIdField = clientIdField.copy(error = "Client ID cannot be empty")) } + return + } + if (!clientId.all { it.isLetterOrDigit() || it == '-' || it == '_' }) { + updateLogin { copy(clientIdField = clientIdField.copy(error ="Client ID may only contain letters, digits, '-', or '_'")) } + return + } + if (host.isBlank()) { + updateLogin { copy(hostField = hostField.copy(error = "Server host cannot be empty")) } + return + } + + _state.update { + it.copy(currentScreen = AppState.Screen.CHAT, chat = AppState.Chat()) + } + 
observeRepository() + viewModelScope.launch { + chatRepository.connect(clientId, host) + } + } + + private fun handleChatSubmit() { + val currentState = _state.value + val text = currentState.chat.input.trim() + if (text.isEmpty()) return + + updateChat { copy(input = "") } + + val parts = text.split(" ", limit = 2) + val cmd = parts[0] + val arg = parts.getOrElse(1) { "" } + + when (cmd) { + "/name" -> chatRepository.setName(arg) + + "/del" -> { + val id = arg.trim().toULongOrNull() + if (id != null) chatRepository.deleteMessage(id) + else chatRepository.log("Usage: /del ") + } + + "/note" -> { + val noteParts = arg.trim().split(" ", limit = 2) + if (noteParts.size == 2) chatRepository.addNote(noteParts[1], noteParts[0]) + else chatRepository.log("Usage: /note ") + } + + "/delnote" -> { + val id = arg.trim().toULongOrNull() + if (id != null) chatRepository.deleteNote(id) + else chatRepository.log("Usage: /delnote ") + } + + "/unsub" -> chatRepository.unsubscribeNotes() + "/resub" -> chatRepository.resubscribeNotes() + + "/query" -> { + val sql = arg.trim() + if (sql.isEmpty()) chatRepository.log("Usage: /query ") + else chatRepository.oneOffQuery(sql) + } + + "/squery" -> { + val sql = arg.trim() + if (sql.isEmpty()) chatRepository.log("Usage: /squery ") + else viewModelScope.launch(Dispatchers.Default) { + chatRepository.suspendOneOffQuery(sql) + } + } + + "/remind" -> { + val remindParts = arg.trim().split(" ", limit = 2) + val delayMs = remindParts.getOrNull(0)?.toULongOrNull() + val remindText = remindParts.getOrNull(1) + if (delayMs != null && remindText != null) chatRepository.scheduleReminder( + remindText, + delayMs + ) + else chatRepository.log("Usage: /remind ") + } + + "/remind-cancel" -> { + val id = arg.trim().toULongOrNull() + if (id != null) chatRepository.cancelReminder(id) + else chatRepository.log("Usage: /remind-cancel ") + } + + "/remind-repeat" -> { + val remindParts = arg.trim().split(" ", limit = 2) + val intervalMs = 
remindParts.getOrNull(0)?.toULongOrNull() + val remindText = remindParts.getOrNull(1) + if (intervalMs != null && remindText != null) chatRepository.scheduleReminderRepeat( + remindText, + intervalMs + ) + else chatRepository.log("Usage: /remind-repeat ") + } + + else -> chatRepository.sendMessage(text) + } + } + + private fun handleLogout() { + observationJob?.cancel() + _state.update { + it.copy(chat = AppState.Chat(), currentScreen = AppState.Screen.LOGIN) + } + viewModelScope.launch { chatRepository.disconnect() } + } + + private fun observeRepository() { + observationJob?.cancel() + observationJob = viewModelScope.launch { + chatRepository.connected + .onEach { connected -> updateChat { copy(connected = connected) } } + .launchIn(this) + + chatRepository.lines + .onEach { lines -> + updateChat { + copy(lines = lines.map { it.toChatLine() }.toImmutableList()) + } + } + .launchIn(this) + + chatRepository.onlineUsers + .onEach { users -> updateChat { copy(onlineUsers = users.toImmutableList()) } } + .launchIn(this) + + chatRepository.offlineUsers + .onEach { users -> updateChat { copy(offlineUsers = users.toImmutableList()) } } + .launchIn(this) + + chatRepository.notes + .onEach { notes -> + updateChat { + copy(notes = notes.map { it.toNoteUi() }.toImmutableList()) + } + } + .launchIn(this) + + chatRepository.noteSubState + .onEach { state -> updateChat { copy(noteSubState = state) } } + .launchIn(this) + + chatRepository.connectionError + .onEach { error -> + if (error != null) { + _state.update { + it.copy( + currentScreen = AppState.Screen.LOGIN, + login = it.login.copy( + hostField = it.login.hostField.copy(error = error), + ), + chat = AppState.Chat(), + ) + } + } + } + .launchIn(this) + } + } + + private inline fun updateLogin(block: AppState.Login.() -> AppState.Login) { + _state.update { it.copy(login = block(it.login)) } + } + + private inline fun updateChat(block: AppState.Chat.() -> AppState.Chat) { + _state.update { it.copy(chat = block(it.chat)) } + 
} + + override fun onCleared() { + observationJob?.cancel() + runBlocking { chatRepository.disconnect() } + } + + companion object { + fun formatTimeStamp(timeStamp: Timestamp): String { + val dt = timeStamp.instant.toLocalDateTime(TimeZone.currentSystemDefault()) + + val year = dt.year.toString().padStart(4, '0') + val month = dt.month.number.toString().padStart(2, '0') + val day = dt.day.toString().padStart(2, '0') + val hour = dt.hour.toString().padStart(2, '0') + val minute = dt.minute.toString().padStart(2, '0') + val second = dt.second.toString().padStart(2, '0') + val millisecond = (dt.nanosecond / 1_000_000).toString().padStart(3, '0') + + return "$year-$month-$day $hour:$minute:$second.$millisecond" + } + + private fun ChatLineData.toChatLine(): AppState.Chat.ChatLine = when (this) { + is ChatLineData.Message -> AppState.Chat.ChatLine.Msg(id, sender, text, sent) + is ChatLineData.System -> AppState.Chat.ChatLine.System(text) + } + + private fun NoteData.toNoteUi(): AppState.Chat.NoteUi = + AppState.Chat.NoteUi(id, tag, content) + } +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/ChatRepository.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/ChatRepository.kt new file mode 100644 index 00000000000..e6f538faeda --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/ChatRepository.kt @@ -0,0 +1,411 @@ +package app + +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnectionView +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Status +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionHandle +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.onFailure +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.onSuccess +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Identity +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.type.Timestamp +import io.ktor.client.HttpClient +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch +import module_bindings.RemoteTables +import module_bindings.SpacetimeConfig +import module_bindings.User +import module_bindings.addQuery +import module_bindings.db +import module_bindings.reducers +import module_bindings.withModuleBindings + +sealed interface ChatLineData { + data class Message( + val id: ULong, + val sender: String, + val text: String, + val sent: Timestamp, + ) : ChatLineData + + data class System(val text: String) : ChatLineData +} + +data class NoteData( + val id: ULong, + val tag: String, + val content: String, +) + +class ChatRepository( + private val httpClient: HttpClient, + private val tokenStore: TokenStore, +) { + @Volatile private var conn: DbConnection? = null + @Volatile private var mainSubHandle: SubscriptionHandle? = null + @Volatile private var noteSubHandle: SubscriptionHandle? = null + @Volatile private var localIdentity: Identity? = null + @Volatile private var clientId: String? 
= null + + private val _connected = MutableStateFlow(false) + val connected: StateFlow = _connected.asStateFlow() + + private val _connectionError = MutableStateFlow(null) + val connectionError: StateFlow = _connectionError.asStateFlow() + + private val _lines = MutableStateFlow>(emptyList()) + val lines: StateFlow> = _lines.asStateFlow() + + private val _onlineUsers = MutableStateFlow>(emptyList()) + val onlineUsers: StateFlow> = _onlineUsers.asStateFlow() + + private val _offlineUsers = MutableStateFlow>(emptyList()) + val offlineUsers: StateFlow> = _offlineUsers.asStateFlow() + + private val _notes = MutableStateFlow>(emptyList()) + val notes: StateFlow> = _notes.asStateFlow() + + private val _noteSubState = MutableStateFlow("none") + val noteSubState: StateFlow = _noteSubState.asStateFlow() + + fun log(text: String) { + _lines.update { it + ChatLineData.System(text) } + } + + suspend fun connect(clientId: String, host: String) { + _connectionError.value = null + this.clientId = clientId + val connection = DbConnection.Builder() + .withHttpClient(httpClient) + .withUri(host) + .withDatabaseName(SpacetimeConfig.DATABASE_NAME) + .withToken(tokenStore.load(clientId)) + .withModuleBindings() + .onConnect { c, identity, token -> + localIdentity = identity + CoroutineScope(Dispatchers.IO).launch { + tokenStore.save(clientId, token) + } + log("Identity: ${identity.toHexString().take(16)}...") + + registerTableCallbacks(c) + registerReducerCallbacks(c) + registerSubscriptions(c) + } + .onConnectError { _, e -> + _connectionError.value = e.message ?: "Connection failed" + } + .onDisconnect { _, error -> + _connected.value = false + _onlineUsers.value = emptyList() + _offlineUsers.value = emptyList() + _notes.value = emptyList() + if (error != null) { + log("Disconnected abnormally: $error") + } else { + log("Disconnected.") + } + } + .build() + + conn = connection + } + + suspend fun disconnect() { + conn?.disconnect() + conn = null + mainSubHandle = null + noteSubHandle 
= null + localIdentity = null + clientId = null + _connected.value = false + _lines.value = emptyList() + _onlineUsers.value = emptyList() + _offlineUsers.value = emptyList() + _notes.value = emptyList() + _noteSubState.value = "none" + } + + // --- Commands --- + + fun sendMessage(text: String) { + conn?.reducers?.sendMessage(text) + } + + fun setName(name: String) { + conn?.reducers?.setName(name) + } + + fun deleteMessage(id: ULong) { + conn?.reducers?.deleteMessage(id) + } + + fun addNote(content: String, tag: String) { + conn?.reducers?.addNote(content, tag) + } + + fun deleteNote(id: ULong) { + conn?.reducers?.deleteNote(id) + } + + fun unsubscribeNotes() { + val handle = noteSubHandle + if (handle != null && handle.isActive) { + handle.unsubscribeThen { _ -> + _notes.value = emptyList() + _noteSubState.value = "ended" + log("Note subscription unsubscribed.") + } + } else { + log("Note subscription is not active (state: ${handle?.state})") + } + } + + fun resubscribeNotes() { + val c = conn ?: return + noteSubHandle = c.subscriptionBuilder() + .onApplied { ctx -> + refreshNotes(ctx.db) + log("Note subscription re-applied (${_notes.value.size} notes).") + _noteSubState.value = noteSubHandle?.state?.toString() ?: "applied" + } + .onError { _, error -> + log("Note subscription error: $error") + } + .addQuery { qb -> qb.note() } + .subscribe() + _noteSubState.value = noteSubHandle?.state?.toString() ?: "pending" + log("Re-subscribing to notes...") + } + + fun oneOffQuery(sql: String) { + val c = conn ?: return + c.oneOffQuery(sql) { result -> + result + .onSuccess { data -> log("OneOffQuery OK: ${data.tableCount} table(s)") } + .onFailure { error -> log("OneOffQuery error: $error") } + } + log("Executing: $sql") + } + + suspend fun suspendOneOffQuery(sql: String) { + val c = conn ?: return + log("Executing (suspend): $sql") + c.oneOffQuery(sql) + .onSuccess { data -> log("SuspendQuery OK: ${data.tableCount} table(s)") } + .onFailure { error -> log("SuspendQuery 
error: $error") } + } + + fun scheduleReminder(text: String, delayMs: ULong) { + conn?.reducers?.scheduleReminder(text, delayMs) + } + + fun cancelReminder(id: ULong) { + conn?.reducers?.cancelReminder(id) + } + + fun scheduleReminderRepeat(text: String, intervalMs: ULong) { + conn?.reducers?.scheduleReminderRepeat(text, intervalMs) + } + + // --- Private --- + + private fun registerTableCallbacks(c: DbConnectionView) { + c.db.user.onInsert { ctx, user -> + refreshUsers(c.db) + if (ctx !is EventContext.SubscribeApplied && user.online) { + log("${userNameOrIdentity(user)} is online") + } + } + + c.db.user.onUpdate { _, oldUser, newUser -> + refreshUsers(c.db) + if (oldUser.name != newUser.name) { + log("${userNameOrIdentity(oldUser)} renamed to ${newUser.name}") + } + if (oldUser.online != newUser.online) { + if (newUser.online) { + log("${userNameOrIdentity(newUser)} connected.") + } else { + log("${userNameOrIdentity(newUser)} disconnected.") + } + } + } + + c.db.message.onInsert { ctx, message -> + if (ctx is EventContext.SubscribeApplied) return@onInsert + _lines.update { + it + ChatLineData.Message( + message.id, + senderName(c.db, message.sender), + message.text, + message.sent, + ) + } + } + + c.db.message.onDelete { ctx, message -> + if (ctx is EventContext.SubscribeApplied) return@onDelete + _lines.update { lines -> + lines.filter { it !is ChatLineData.Message || it.id != message.id } + } + log("Message #${message.id} deleted") + } + + c.db.note.onInsert { ctx, _ -> + if (ctx is EventContext.SubscribeApplied) return@onInsert + refreshNotes(c.db) + } + + c.db.note.onDelete { ctx, note -> + if (ctx is EventContext.SubscribeApplied) return@onDelete + refreshNotes(c.db) + log("Note #${note.id} deleted") + } + + c.db.reminder.onInsert { ctx, reminder -> + if (ctx is EventContext.SubscribeApplied) return@onInsert + log("Reminder scheduled: \"${reminder.text}\" (id=${reminder.scheduledId})") + } + + c.db.reminder.onDelete { ctx, reminder -> + if (ctx is 
EventContext.SubscribeApplied) return@onDelete + log("Reminder consumed: \"${reminder.text}\" (id=${reminder.scheduledId})") + } + } + + private fun registerReducerCallbacks(c: DbConnectionView) { + c.reducers.onSetName { ctx, name -> + if (ctx.callerIdentity == localIdentity && ctx.status is Status.Failed) { + log("Failed to change name to $name: ${(ctx.status as Status.Failed).message}") + } + } + + c.reducers.onSendMessage { ctx, text -> + if (ctx.callerIdentity == localIdentity && ctx.status is Status.Failed) { + log("Failed to send message \"$text\": ${(ctx.status as Status.Failed).message}") + } + } + + c.reducers.onDeleteMessage { ctx, messageId -> + if (ctx.callerIdentity == localIdentity && ctx.status is Status.Failed) { + log("Failed to delete message #$messageId: ${(ctx.status as Status.Failed).message}") + } + } + + c.reducers.onAddNote { ctx, _, tag -> + if (ctx.callerIdentity == localIdentity) { + if (ctx.status is Status.Committed) { + log("Note added (tag=$tag)") + } else if (ctx.status is Status.Failed) { + log("Failed to add note: ${(ctx.status as Status.Failed).message}") + } + } + } + + c.reducers.onDeleteNote { ctx, noteId -> + if (ctx.callerIdentity == localIdentity && ctx.status is Status.Failed) { + log("Failed to delete note #$noteId: ${(ctx.status as Status.Failed).message}") + } + } + + c.reducers.onScheduleReminder { ctx, text, delayMs -> + if (ctx.callerIdentity == localIdentity) { + if (ctx.status is Status.Committed) { + log("Reminder scheduled in ${delayMs}ms: \"$text\"") + } else if (ctx.status is Status.Failed) { + log("Failed to schedule reminder: ${(ctx.status as Status.Failed).message}") + } + } + } + + c.reducers.onCancelReminder { ctx, reminderId -> + if (ctx.callerIdentity == localIdentity) { + if (ctx.status is Status.Committed) { + log("Reminder #$reminderId cancelled") + } else if (ctx.status is Status.Failed) { + log("Failed to cancel reminder #$reminderId: ${(ctx.status as Status.Failed).message}") + } + } + } + + 
c.reducers.onScheduleReminderRepeat { ctx, text, intervalMs -> + if (ctx.callerIdentity == localIdentity) { + if (ctx.status is Status.Committed) { + log("Repeating reminder every ${intervalMs}ms: \"$text\"") + } else if (ctx.status is Status.Failed) { + log("Failed to schedule repeating reminder: ${(ctx.status as Status.Failed).message}") + } + } + } + } + + private fun registerSubscriptions(c: DbConnectionView) { + mainSubHandle = c.subscriptionBuilder() + .onApplied { ctx -> + _connected.value = true + refreshUsers(ctx.db) + val initialMessages = ctx.db.message.all() + .sortedBy { it.sent } + .map { msg -> + ChatLineData.Message( + msg.id, + senderName(ctx.db, msg.sender), + msg.text, + msg.sent, + ) + } + _lines.update { initialMessages } + log("Main subscription applied.") + } + .onError { _, error -> + log("Main subscription error: $error") + } + .subscribe( + listOf( + "SELECT * FROM user", + "SELECT * FROM message", + "SELECT * FROM reminder", + ) + ) + + // Type-safe query builder — equivalent to .subscribe("SELECT * FROM note") + noteSubHandle = c.subscriptionBuilder() + .onApplied { ctx -> + refreshNotes(ctx.db) + log("Note subscription applied (${_notes.value.size} notes).") + _noteSubState.value = noteSubHandle?.state?.toString() ?: "applied" + } + .onError { _, error -> + log("Note subscription error: $error") + } + .addQuery { qb -> qb.note() } + .subscribe() + _noteSubState.value = noteSubHandle?.state?.toString() ?: "pending" + } + + private fun refreshUsers(db: RemoteTables) { + val all = db.user.all() + _onlineUsers.value = all.filter { it.online }.map { userNameOrIdentity(it) } + _offlineUsers.value = all.filter { !it.online }.map { userNameOrIdentity(it) } + } + + private fun refreshNotes(db: RemoteTables) { + _notes.value = db.note.all().map { NoteData(it.id, it.tag, it.content) } + } + + companion object { + private fun userNameOrIdentity(user: User): String = + user.name ?: user.identity.toHexString().take(8) + + private fun senderName(db: 
RemoteTables, sender: Identity): String { + val user = db.user.identity.find(sender) + return if (user != null) userNameOrIdentity(user) else "unknown" + } + } +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/TokenStore.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/TokenStore.kt new file mode 100644 index 00000000000..b613aedf53c --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/TokenStore.kt @@ -0,0 +1,6 @@ +package app + +expect class TokenStore { + fun load(clientId: String): String? + fun save(clientId: String, token: String) +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/AppScreen.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/AppScreen.kt new file mode 100644 index 00000000000..a012213e1ac --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/AppScreen.kt @@ -0,0 +1,39 @@ +package app.composable + +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.imePadding +import androidx.compose.foundation.layout.padding +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Scaffold +import androidx.compose.material3.darkColorScheme +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.ui.Modifier +import androidx.lifecycle.compose.collectAsStateWithLifecycle +import app.AppState +import app.AppViewModel + +@Composable +fun App(viewModel: AppViewModel) { + val state by viewModel.state.collectAsStateWithLifecycle() + + MaterialTheme(colorScheme = darkColorScheme()) { + Scaffold( + modifier = Modifier.fillMaxSize().imePadding() + ) { innerPadding -> + when (state.currentScreen) { + AppState.Screen.LOGIN -> LoginScreen( + state = state.login, + onAction = viewModel::onAction, + modifier = Modifier.padding(innerPadding), + ) + + AppState.Screen.CHAT -> ChatScreen( + 
state = state.chat, + onAction = viewModel::onAction, + modifier = Modifier.padding(innerPadding), + ) + } + } + } +} \ No newline at end of file diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/ChatScreen.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/ChatScreen.kt new file mode 100644 index 00000000000..ff382f65026 --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/ChatScreen.kt @@ -0,0 +1,355 @@ +package app.composable + +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.BoxWithConstraints +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.WindowInsets +import androidx.compose.foundation.layout.fillMaxHeight +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.ime +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.LazyListState +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.lazy.rememberLazyListState +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material3.Button +import androidx.compose.material3.DrawerValue +import androidx.compose.material3.HorizontalDivider +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.ModalDrawerSheet +import androidx.compose.material3.ModalNavigationDrawer +import androidx.compose.material3.OutlinedButton +import androidx.compose.material3.OutlinedTextField +import androidx.compose.material3.Text +import 
androidx.compose.material3.VerticalDivider +import androidx.compose.material3.rememberDrawerState +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalDensity +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.unit.dp +import app.AppAction +import app.AppState +import app.AppViewModel +import kotlinx.collections.immutable.ImmutableList +import kotlinx.coroutines.launch + +@Composable +fun ChatScreen( + state: AppState.Chat, + onAction: (AppAction.Chat) -> Unit, + modifier: Modifier = Modifier, +) { + val listState = rememberLazyListState() + + LaunchedEffect(state.lines) { + if (state.lines.isNotEmpty()) { + listState.animateScrollToItem(state.lines.size - 1) + } + } + + BoxWithConstraints(modifier = modifier.fillMaxWidth()) { + val isCompact = maxWidth < 600.dp + + if (isCompact) { + CompactChatScreen(state, onAction, listState) + } else { + WideChatScreen(state, onAction, listState) + } + } +} + +@Composable +private fun WideChatScreen( + state: AppState.Chat, + onAction: (AppAction.Chat) -> Unit, + listState: LazyListState, +) { + Row(modifier = Modifier.fillMaxSize()) { + ChatPanel( + state = state, + onAction = onAction, + listState = listState, + modifier = Modifier.weight(1f).fillMaxHeight(), + ) + + VerticalDivider() + + Sidebar( + onlineUsers = state.onlineUsers, + offlineUsers = state.offlineUsers, + notes = state.notes, + noteSubState = state.noteSubState, + modifier = Modifier.width(200.dp).fillMaxHeight(), + onLogout = { onAction(AppAction.Chat.Logout) }, + ) + } +} + +@Composable +private fun CompactChatScreen( + state: AppState.Chat, + onAction: (AppAction.Chat) -> Unit, + listState: LazyListState, +) { + val drawerState = rememberDrawerState(DrawerValue.Closed) + val 
scope = rememberCoroutineScope() + + ModalNavigationDrawer( + drawerState = drawerState, + drawerContent = { + ModalDrawerSheet { + Sidebar( + onlineUsers = state.onlineUsers, + offlineUsers = state.offlineUsers, + notes = state.notes, + noteSubState = state.noteSubState, + modifier = Modifier.fillMaxHeight().padding(8.dp), + onLogout = { onAction(AppAction.Chat.Logout) }, + ) + } + }, + ) { + ChatPanel( + state = state, + onAction = onAction, + listState = listState, + modifier = Modifier.fillMaxSize(), + onUsersClicked = { scope.launch { drawerState.open() } }, + ) + } +} + +@Composable +private fun ChatPanel( + state: AppState.Chat, + onAction: (AppAction.Chat) -> Unit, + listState: LazyListState, + modifier: Modifier = Modifier, + onUsersClicked: (() -> Unit)? = null, +) { + val imeBottom = WindowInsets.ime.getBottom(LocalDensity.current) + + LaunchedEffect(imeBottom) { + if (imeBottom > 0 && state.lines.isNotEmpty()) { + listState.animateScrollToItem(state.lines.size - 1) + } + } + + Column(modifier = modifier.padding(8.dp)) { + if (onUsersClicked != null) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically, + ) { + Text( + state.dbName, + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + ) + OutlinedButton(onClick = onUsersClicked) { + Text("Users (${state.onlineUsers.size})") + } + } + Spacer(Modifier.height(4.dp)) + } + + LazyColumn( + state = listState, + modifier = Modifier.weight(1f).fillMaxWidth(), + verticalArrangement = Arrangement.spacedBy(2.dp), + ) { + if (!state.connected) { + item { + Text( + "Connecting to ${state.dbName}...", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + } + + items( + items = state.lines, + key = { line -> + when (line) { + is AppState.Chat.ChatLine.Msg -> "msg-${line.id}" + is AppState.Chat.ChatLine.System -> "sys-${line.hashCode()}" + } + 
}, + ) { line -> + when (line) { + is AppState.Chat.ChatLine.Msg -> Row(verticalAlignment = Alignment.Bottom) { + Text( + "#${line.id} ${line.sender}: ${line.text}", + style = MaterialTheme.typography.bodyMedium, + modifier = Modifier.weight(1f, fill = false), + ) + Spacer(Modifier.width(8.dp)) + Text( + AppViewModel.formatTimeStamp(line.sent), + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + + is AppState.Chat.ChatLine.System -> Text( + line.text, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + } + } + + Spacer(Modifier.height(4.dp)) + + Text( + "/name | /del | /note | /delnote | /unsub | /resub | /query | /squery | /remind | /remind-repeat | /remind-cancel", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + + Spacer(Modifier.height(4.dp)) + + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + ) { + OutlinedTextField( + value = state.input, + onValueChange = { onAction(AppAction.Chat.UpdateInput(it)) }, + keyboardOptions = KeyboardOptions(imeAction = ImeAction.Send), + keyboardActions = KeyboardActions(onSend = { onAction(AppAction.Chat.Submit) }), + modifier = Modifier.weight(1f), + placeholder = { Text("Type a message...") }, + singleLine = true, + enabled = state.connected, + ) + + Spacer(Modifier.width(8.dp)) + + Button( + onClick = { onAction(AppAction.Chat.Submit) }, + enabled = state.connected && state.input.isNotBlank(), + ) { + Text("Send") + } + } + } +} + +@Composable +private fun Sidebar( + onlineUsers: ImmutableList, + offlineUsers: ImmutableList, + notes: ImmutableList, + noteSubState: String, + modifier: Modifier = Modifier, + onLogout: (() -> Unit)? 
= null, +) { + Column(modifier = modifier) { + Text( + "Online", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + ) + + Spacer(Modifier.height(4.dp)) + + if (onlineUsers.isEmpty()) { + Text( + "No users online", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + + onlineUsers.forEach { name -> + Text( + name, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.primary, + ) + } + + if (offlineUsers.isNotEmpty()) { + Spacer(Modifier.height(12.dp)) + + Text( + "Offline", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + ) + + Spacer(Modifier.height(4.dp)) + + offlineUsers.forEach { name -> + Text( + name, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + } + + Spacer(Modifier.height(16.dp)) + + HorizontalDivider() + + Spacer(Modifier.height(8.dp)) + + Text( + "Notes", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + ) + + Text( + "sub: $noteSubState", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + + Spacer(Modifier.height(4.dp)) + + if (notes.isEmpty()) { + Text( + "No notes", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + } + + notes.forEach { note -> + Text( + "#${note.id} [${note.tag}] ${note.content}", + style = MaterialTheme.typography.bodySmall, + ) + } + + if (onLogout != null) { + Spacer(Modifier.weight(1f)) + HorizontalDivider() + Spacer(Modifier.height(8.dp)) + OutlinedButton( + onClick = onLogout, + modifier = Modifier.fillMaxWidth(), + ) { + Text("Logout") + } + } + } +} diff --git a/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/LoginScreen.kt b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/LoginScreen.kt new file mode 100644 index 
00000000000..90ccf5d2fef --- /dev/null +++ b/templates/compose-kt/sharedClient/src/commonMain/kotlin/app/composable/LoginScreen.kt @@ -0,0 +1,78 @@ +package app.composable + +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material3.Button +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.OutlinedTextField +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.FocusDirection +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.unit.dp +import app.AppAction +import app.AppState + +@Composable +fun LoginScreen( + state: AppState.Login, + onAction: (AppAction.Login) -> Unit, + modifier: Modifier = Modifier, +) { + val focusManager = LocalFocusManager.current + + Column( + modifier = modifier.fillMaxSize().padding(16.dp), + verticalArrangement = Arrangement.Center, + horizontalAlignment = Alignment.CenterHorizontally, + ) { + Text("SpacetimeDB Chat", style = MaterialTheme.typography.headlineMedium) + + Spacer(Modifier.height(16.dp)) + + OutlinedTextField( + value = state.clientIdField.value, + onValueChange = { onAction(AppAction.Login.OnClientChanged(it)) }, + label = { Text("Client ID") }, + singleLine = true, + isError = state.clientIdField.error != null, + supportingText = state.clientIdField.error?.let { error -> { Text(error) } }, + keyboardOptions = KeyboardOptions(imeAction 
= ImeAction.Next), + keyboardActions = KeyboardActions(onNext = { focusManager.moveFocus(FocusDirection.Down) }), + modifier = Modifier.width(300.dp), + ) + + Spacer(Modifier.height(8.dp)) + + OutlinedTextField( + value = state.hostField.value, + onValueChange = { onAction(AppAction.Login.OnHostChanged(it)) }, + label = { Text("Server Host") }, + singleLine = true, + isError = state.hostField.error != null, + supportingText = state.hostField.error?.let { error -> { Text(error) } }, + keyboardOptions = KeyboardOptions(imeAction = ImeAction.Send), + keyboardActions = KeyboardActions(onSend = { + focusManager.clearFocus() + onAction(AppAction.Login.OnSubmitClicked) + }), + modifier = Modifier.width(300.dp), + ) + + Spacer(Modifier.height(8.dp)) + + Button(onClick = { onAction(AppAction.Login.OnSubmitClicked) }) { + Text("Connect") + } + } +} diff --git a/templates/compose-kt/sharedClient/src/jvmMain/kotlin/app/TokenStore.jvm.kt b/templates/compose-kt/sharedClient/src/jvmMain/kotlin/app/TokenStore.jvm.kt new file mode 100644 index 00000000000..c80c0fdba7c --- /dev/null +++ b/templates/compose-kt/sharedClient/src/jvmMain/kotlin/app/TokenStore.jvm.kt @@ -0,0 +1,24 @@ +package app + +import java.io.File + +actual class TokenStore { + private val tokenDir = File(System.getProperty("user.home"), ".spacetimedb/tokens") + + private fun tokenFile(clientId: String): File { + require(clientId.isNotEmpty() && clientId.all { it.isLetterOrDigit() || it == '-' || it == '_' }) { + "Invalid clientId: must be non-empty and contain only alphanumeric, '-', or '_' characters" + } + return File(tokenDir, clientId) + } + + actual fun load(clientId: String): String? 
{ + val file = tokenFile(clientId) + return if (file.exists()) file.readText().trim() else null + } + + actual fun save(clientId: String, token: String) { + tokenDir.mkdirs() + tokenFile(clientId).writeText(token) + } +} diff --git a/templates/compose-kt/spacetime.json b/templates/compose-kt/spacetime.json new file mode 100644 index 00000000000..1a018167a50 --- /dev/null +++ b/templates/compose-kt/spacetime.json @@ -0,0 +1,4 @@ +{ + "server": "local", + "module-path": "./spacetimedb" +} \ No newline at end of file diff --git a/templates/compose-kt/spacetimedb/Cargo.toml b/templates/compose-kt/spacetimedb/Cargo.toml new file mode 100644 index 00000000000..8f564a5facd --- /dev/null +++ b/templates/compose-kt/spacetimedb/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "chat_kt" +version = "0.1.0" +edition = "2021" +license-file = "LICENSE" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = ["cdylib"] + +[dependencies] +spacetimedb = { version = "2.0.1" } +log.version = "0.4.17" diff --git a/templates/compose-kt/spacetimedb/src/lib.rs b/templates/compose-kt/spacetimedb/src/lib.rs new file mode 100644 index 00000000000..3c48911bcea --- /dev/null +++ b/templates/compose-kt/spacetimedb/src/lib.rs @@ -0,0 +1,214 @@ +use spacetimedb::{Identity, ReducerContext, ScheduleAt, Table, Timestamp}; + +#[spacetimedb::table(accessor = user, public)] +pub struct User { + #[primary_key] + identity: Identity, + name: Option, + online: bool, +} + +#[spacetimedb::table(accessor = message, public)] +pub struct Message { + #[auto_inc] + #[primary_key] + id: u64, + sender: Identity, + sent: Timestamp, + text: String, +} + +/// A simple note table — used to test onDelete and filtered subscriptions. 
+#[spacetimedb::table(accessor = note, public)] +pub struct Note { + #[auto_inc] + #[primary_key] + id: u64, + owner: Identity, + content: String, + tag: String, +} + +/// Scheduled table — tests ScheduleAt and TimeDuration types. +/// When a row's scheduled_at time arrives, the server calls send_reminder. +#[spacetimedb::table(accessor = reminder, public, scheduled(send_reminder))] +pub struct Reminder { + #[primary_key] + #[auto_inc] + scheduled_id: u64, + scheduled_at: ScheduleAt, + text: String, + owner: Identity, +} + +fn validate_name(name: String) -> Result { + if name.is_empty() { + Err("Names must not be empty".to_string()) + } else { + Ok(name) + } +} + +#[spacetimedb::reducer] +pub fn set_name(ctx: &ReducerContext, name: String) -> Result<(), String> { + let name = validate_name(name)?; + if let Some(user) = ctx.db.user().identity().find(ctx.sender()) { + log::info!("User {} sets name to {name}", ctx.sender()); + ctx.db.user().identity().update(User { + name: Some(name), + ..user + }); + Ok(()) + } else { + Err("Cannot set name for unknown user".to_string()) + } +} + +fn validate_message(text: String) -> Result { + if text.is_empty() { + Err("Messages must not be empty".to_string()) + } else { + Ok(text) + } +} + +#[spacetimedb::reducer] +pub fn send_message(ctx: &ReducerContext, text: String) -> Result<(), String> { + let text = validate_message(text)?; + log::info!("User {}: {text}", ctx.sender()); + ctx.db.message().insert(Message { + id: 0, + sender: ctx.sender(), + text, + sent: ctx.timestamp, + }); + Ok(()) +} + +#[spacetimedb::reducer] +pub fn delete_message(ctx: &ReducerContext, message_id: u64) -> Result<(), String> { + if let Some(msg) = ctx.db.message().id().find(message_id) { + if msg.sender != ctx.sender() { + return Err("Cannot delete another user's message".to_string()); + } + ctx.db.message().id().delete(message_id); + log::info!("User {} deleted message {message_id}", ctx.sender()); + Ok(()) + } else { + Err("Message not 
found".to_string()) + } +} + +#[spacetimedb::reducer] +pub fn add_note(ctx: &ReducerContext, content: String, tag: String) -> Result<(), String> { + if content.is_empty() { + return Err("Note content must not be empty".to_string()); + } + ctx.db.note().insert(Note { + id: 0, + owner: ctx.sender(), + content, + tag, + }); + Ok(()) +} + +#[spacetimedb::reducer] +pub fn delete_note(ctx: &ReducerContext, note_id: u64) -> Result<(), String> { + if let Some(note) = ctx.db.note().id().find(note_id) { + if note.owner != ctx.sender() { + return Err("Cannot delete another user's note".to_string()); + } + ctx.db.note().id().delete(note_id); + Ok(()) + } else { + Err("Note not found".to_string()) + } +} + +/// Schedule a one-shot reminder that fires after delay_ms milliseconds. +#[spacetimedb::reducer] +pub fn schedule_reminder(ctx: &ReducerContext, text: String, delay_ms: u64) -> Result<(), String> { + if text.is_empty() { + return Err("Reminder text must not be empty".to_string()); + } + let at = ctx.timestamp + std::time::Duration::from_millis(delay_ms); + ctx.db.reminder().insert(Reminder { + scheduled_id: 0, + scheduled_at: ScheduleAt::Time(at), + text: text.clone(), + owner: ctx.sender(), + }); + log::info!("User {} scheduled reminder in {delay_ms}ms: {text}", ctx.sender()); + Ok(()) +} + +/// Schedule a repeating reminder that fires every interval_ms milliseconds. +#[spacetimedb::reducer] +pub fn schedule_reminder_repeat(ctx: &ReducerContext, text: String, interval_ms: u64) -> Result<(), String> { + if text.is_empty() { + return Err("Reminder text must not be empty".to_string()); + } + let interval = std::time::Duration::from_millis(interval_ms); + ctx.db.reminder().insert(Reminder { + scheduled_id: 0, + scheduled_at: interval.into(), + text: text.clone(), + owner: ctx.sender(), + }); + log::info!("User {} scheduled repeating reminder every {interval_ms}ms: {text}", ctx.sender()); + Ok(()) +} + +/// Cancel a scheduled reminder by id. 
+#[spacetimedb::reducer] +pub fn cancel_reminder(ctx: &ReducerContext, reminder_id: u64) -> Result<(), String> { + if let Some(reminder) = ctx.db.reminder().scheduled_id().find(reminder_id) { + if reminder.owner != ctx.sender() { + return Err("Cannot cancel another user's reminder".to_string()); + } + ctx.db.reminder().scheduled_id().delete(reminder_id); + log::info!("User {} cancelled reminder {reminder_id}", ctx.sender()); + Ok(()) + } else { + Err("Reminder not found".to_string()) + } +} + +/// Called by the scheduler when a reminder fires. +#[spacetimedb::reducer] +pub fn send_reminder(ctx: &ReducerContext, reminder: Reminder) { + log::info!("Reminder fired for {}: {}", reminder.owner, reminder.text); + // Insert a system message so the client sees it + ctx.db.message().insert(Message { + id: 0, + sender: reminder.owner, + text: format!("[REMINDER] {}", reminder.text), + sent: ctx.timestamp, + }); +} + +#[spacetimedb::reducer(init)] +pub fn init(_ctx: &ReducerContext) {} + +#[spacetimedb::reducer(client_connected)] +pub fn identity_connected(ctx: &ReducerContext) { + if let Some(user) = ctx.db.user().identity().find(ctx.sender()) { + ctx.db.user().identity().update(User { online: true, ..user }); + } else { + ctx.db.user().insert(User { + name: None, + identity: ctx.sender(), + online: true, + }); + } +} + +#[spacetimedb::reducer(client_disconnected)] +pub fn identity_disconnected(ctx: &ReducerContext) { + if let Some(user) = ctx.db.user().identity().find(ctx.sender()) { + ctx.db.user().identity().update(User { online: false, ..user }); + } else { + log::warn!("Disconnect event for unknown user with identity {:?}", ctx.sender()); + } +} diff --git a/templates/keynote-2/spacetimedb-kotlin-client/.gitignore b/templates/keynote-2/spacetimedb-kotlin-client/.gitignore new file mode 100644 index 00000000000..34831c1718b --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/.gitignore @@ -0,0 +1,44 @@ +*.iml +.kotlin/ +.gradle/ +**/build/ +xcuserdata/ 
+!src/**/build/ +local.properties +.idea/ +.DS_Store +captures +.externalNativeBuild +.cxx +*.xcodeproj/* +!*.xcodeproj/project.pbxproj +!*.xcodeproj/xcshareddata/ +!*.xcodeproj/project.xcworkspace/ +!*.xcworkspace/contents.xcworkspacedata +**/xcshareddata/WorkspaceSettings.xcsettings + +# Logs +*.log + +# Database files +*.db +*.db-shm +*.db-wal + +# Server data directory +/data/ +server/data/ + +# Environment files +.env +.env.local + +# OS specific +Thumbs.db +.Trashes +._* + +# IDE specific +*.swp +*~ +.vscode/ diff --git a/templates/keynote-2/spacetimedb-kotlin-client/bench.sh b/templates/keynote-2/spacetimedb-kotlin-client/bench.sh new file mode 100755 index 00000000000..0ebbf3a3798 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/bench.sh @@ -0,0 +1,166 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Kotlin SDK TPS Benchmark Runner +# Usage: ./bench.sh [--duration 10s] [--connections 10] [--server http://localhost:3000] [--module sim] [--runs 2] + +DURATION="10s" +CONNECTIONS=10 +SERVER="http://localhost:3000" +MODULE="sim" +RUNS=2 + +usage() { + cat < Benchmark duration per run (default: $DURATION) + --connections Number of concurrent connections (default: $CONNECTIONS) + --server SpacetimeDB server URL (default: $SERVER) + --module Published module name (default: $MODULE) + --runs Number of benchmark runs (default: $RUNS) + -h, --help Show this help + +Prerequisites: + 1. Build server: cargo build --release -p spacetimedb-cli -p spacetimedb-standalone + 2. Start server: target/release/spacetimedb-cli start + 3. 
Publish module: target/release/spacetimedb-cli publish --server http://localhost:3000 \\ + --module-path templates/keynote-2/rust_module --no-config -y sim + +Examples: + ./bench.sh # defaults + ./bench.sh --duration 30s --connections 20 # heavier load + ./bench.sh --runs 5 # more samples +EOF + exit 0 +} + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) usage ;; + --duration) DURATION="$2"; shift 2 ;; + --connections) CONNECTIONS="$2"; shift 2 ;; + --server) SERVER="$2"; shift 2 ;; + --module) MODULE="$2"; shift 2 ;; + --runs) RUNS="$2"; shift 2 ;; + *) echo "Unknown option: $1 (use --help for usage)"; exit 1 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +BIN="$SCRIPT_DIR/build/install/spacetimedb-kotlin-tps-bench/bin/spacetimedb-kotlin-tps-bench" + +# Build if needed +if [[ ! -x "$BIN" ]]; then + echo "Building..." + "$SCRIPT_DIR/gradlew" installDist --no-daemon -q +fi + +# Monitor function: samples CPU% and RSS every second +monitor() { + local pattern="$1" outfile="$2" + while true; do + for pid in $(pgrep -f "$pattern" 2>/dev/null); do + read cpu rss < <(ps -p "$pid" -o %cpu=,rss= 2>/dev/null) || continue + echo "$cpu $rss" + done >> "$outfile" + sleep 1 + done +} + +# Parse monitor output -> peak CPU, peak RSS (MB) +parse_peak() { + local file="$1" + if [[ ! -s "$file" ]]; then + echo "0 0" + return + fi + awk '{ + if ($1+0 > maxcpu) maxcpu=$1+0 + if ($2+0 > maxrss) maxrss=$2+0 + } END { + printf "%.0f %d\n", maxcpu, maxrss/1024 + }' "$file" +} + +# Format number with comma separators +fmt_num() { + printf "%'d" "$1" 2>/dev/null || printf "%d" "$1" +} + +echo "" +printf "%-14s %s\n" "Server:" "$SERVER" +printf "%-14s %s\n" "Module:" "$MODULE" +printf "%-14s %s\n" "Duration:" "$DURATION" +printf "%-14s %s\n" "Connections:" "$CONNECTIONS" +printf "%-14s %s\n" "Runs:" "$RUNS" +echo "" + +# Seed +printf "Seeding... 
" +"$BIN" seed --server "$SERVER" --module "$MODULE" --quiet 2>/dev/null | grep -v "^\[SpacetimeDB" > /dev/null || true +echo "done" +echo "" + +# Collect results +declare -a TPS_RESULTS CPU_RESULTS RSS_RESULTS + +for i in $(seq 1 "$RUNS"); do + tmpmon=$(mktemp) + + monitor "MainKt" "$tmpmon" & + MON_PID=$! + + output=$("$BIN" bench \ + --server "$SERVER" \ + --module "$MODULE" \ + --duration "$DURATION" \ + --connections "$CONNECTIONS" \ + --quiet 2>/dev/null | grep -v "^\[SpacetimeDB") || true + + kill $MON_PID 2>/dev/null; wait $MON_PID 2>/dev/null || true + + tps=$(echo "$output" | grep "throughput" | grep -oP '[\d.]+(?= TPS)') || tps="0" + tps_int=$(printf "%.0f" "$tps") + + read peak_cpu peak_rss < <(parse_peak "$tmpmon") + rm -f "$tmpmon" + + TPS_RESULTS+=("$tps_int") + CPU_RESULTS+=("$peak_cpu") + RSS_RESULTS+=("$peak_rss") + + printf "Run %d: %s TPS | CPU %s%% | RSS %s MB\n" \ + "$i" "$(fmt_num "$tps_int")" "$peak_cpu" "$(fmt_num "$peak_rss")" + + [[ $i -lt "$RUNS" ]] && sleep 2 +done + +# Summary table +echo "" +echo "┌───────┬────────────────┬───────────┬────────────┐" +echo "│ Run │ TPS │ Peak CPU │ Peak RSS │" +echo "├───────┼────────────────┼───────────┼────────────┤" +for i in $(seq 0 $((RUNS - 1))); do + printf "│ %2d │ %14s │ %5s%% │ %7s MB │\n" \ + $((i + 1)) "$(fmt_num "${TPS_RESULTS[$i]}")" "${CPU_RESULTS[$i]}" "$(fmt_num "${RSS_RESULTS[$i]}")" +done +echo "├───────┼────────────────┼───────────┼────────────┤" + +sum_tps=0; sum_cpu=0; sum_rss=0 +for i in $(seq 0 $((RUNS - 1))); do + sum_tps=$((sum_tps + TPS_RESULTS[i])) + sum_cpu=$((sum_cpu + CPU_RESULTS[i])) + sum_rss=$((sum_rss + RSS_RESULTS[i])) +done +avg_tps=$((sum_tps / RUNS)) +avg_cpu=$((sum_cpu / RUNS)) +avg_rss=$((sum_rss / RUNS)) + +printf "│ avg │ %14s │ %5s%% │ %7s MB │\n" \ + "$(fmt_num "$avg_tps")" "$avg_cpu" "$(fmt_num "$avg_rss")" +echo "└───────┴────────────────┴───────────┴────────────┘" +echo "" diff --git a/templates/keynote-2/spacetimedb-kotlin-client/build.gradle.kts 
b/templates/keynote-2/spacetimedb-kotlin-client/build.gradle.kts new file mode 100644 index 00000000000..2d6050efc71 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/build.gradle.kts @@ -0,0 +1,19 @@ +plugins { + alias(libs.plugins.kotlinJvm) + application +} + +kotlin { + jvmToolchain(21) +} + +application { + mainClass.set("MainKt") +} + +dependencies { + implementation(libs.spacetimedb.sdk) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.ktor.client.okhttp) + implementation(libs.ktor.client.websockets) +} diff --git a/templates/keynote-2/spacetimedb-kotlin-client/gradle/gradle-daemon-jvm.properties b/templates/keynote-2/spacetimedb-kotlin-client/gradle/gradle-daemon-jvm.properties new file mode 100644 index 00000000000..6c1139ec06a --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/gradle/gradle-daemon-jvm.properties @@ -0,0 +1,12 @@ +#This file is generated by updateDaemonJvm +toolchainUrl.FREE_BSD.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.FREE_BSD.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.LINUX.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.LINUX.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.MAC_OS.AARCH64=https\://api.foojay.io/disco/v3.0/ids/73bcfb608d1fde9fb62e462f834a3299/redirect +toolchainUrl.MAC_OS.X86_64=https\://api.foojay.io/disco/v3.0/ids/846ee0d876d26a26f37aa1ce8de73224/redirect +toolchainUrl.UNIX.AARCH64=https\://api.foojay.io/disco/v3.0/ids/ec7520a1e057cd116f9544c42142a16b/redirect +toolchainUrl.UNIX.X86_64=https\://api.foojay.io/disco/v3.0/ids/4c4f879899012ff0a8b2e2117df03b0e/redirect +toolchainUrl.WINDOWS.AARCH64=https\://api.foojay.io/disco/v3.0/ids/9482ddec596298c84656d31d16652665/redirect 
+toolchainUrl.WINDOWS.X86_64=https\://api.foojay.io/disco/v3.0/ids/39701d92e1756bb2f141eb67cd4c660e/redirect +toolchainVersion=21 diff --git a/templates/keynote-2/spacetimedb-kotlin-client/gradle/libs.versions.toml b/templates/keynote-2/spacetimedb-kotlin-client/gradle/libs.versions.toml new file mode 100644 index 00000000000..045249eee80 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/gradle/libs.versions.toml @@ -0,0 +1,15 @@ +[versions] +kotlin = "2.3.10" +kotlinx-coroutines = "1.10.2" +ktor = "3.4.1" +spacetimedb-sdk = "0.1.0" + +[libraries] +spacetimedb-sdk = { module = "com.clockworklabs:spacetimedb-sdk", version.ref = "spacetimedb-sdk" } +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } +ktor-client-okhttp = { module = "io.ktor:ktor-client-okhttp", version.ref = "ktor" } +ktor-client-websockets = { module = "io.ktor:ktor-client-websockets", version.ref = "ktor" } + +[plugins] +kotlinJvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } +spacetimedb = { id = "com.clockworklabs.spacetimedb", version.ref = "spacetimedb-sdk" } diff --git a/templates/keynote-2/spacetimedb-kotlin-client/gradle/wrapper/gradle-wrapper.jar b/templates/keynote-2/spacetimedb-kotlin-client/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000000..2c3521197d7 Binary files /dev/null and b/templates/keynote-2/spacetimedb-kotlin-client/gradle/wrapper/gradle-wrapper.jar differ diff --git a/templates/keynote-2/spacetimedb-kotlin-client/gradle/wrapper/gradle-wrapper.properties b/templates/keynote-2/spacetimedb-kotlin-client/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000000..37f78a6af83 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists 
+distributionUrl=https\://services.gradle.org/distributions/gradle-9.3.1-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/templates/keynote-2/spacetimedb-kotlin-client/gradlew b/templates/keynote-2/spacetimedb-kotlin-client/gradlew new file mode 100755 index 00000000000..f5feea6d6b1 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/gradlew @@ -0,0 +1,252 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». 
+# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s +' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. 
+MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. 
+ # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. 
+ +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/templates/keynote-2/spacetimedb-kotlin-client/gradlew.bat b/templates/keynote-2/spacetimedb-kotlin-client/gradlew.bat new file mode 100644 index 00000000000..9b42019c791 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 
1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/templates/keynote-2/spacetimedb-kotlin-client/settings.gradle.kts b/templates/keynote-2/spacetimedb-kotlin-client/settings.gradle.kts new file mode 100644 index 00000000000..45943882647 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/settings.gradle.kts @@ -0,0 +1,23 @@ +@file:Suppress("UnstableApiUsage") + +rootProject.name = "spacetimedb-kotlin-tps-bench" + +pluginManagement { + repositories { + mavenCentral() + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + repositories { + mavenCentral() + } +} + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "1.0.0" +} + +// Resolve SDK + gradle plugin from the local checkout +includeBuild("../../../sdks/kotlin") diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/Main.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/Main.kt new file mode 100644 index 00000000000..d314ef464bf --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/Main.kt @@ -0,0 +1,236 @@ +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CompressionMode +import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection +import 
com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Status +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.websocket.WebSockets +import jdk.jfr.Configuration +import jdk.jfr.Recording +import kotlinx.coroutines.* +import module_bindings.reducers +import module_bindings.withModuleBindings +import java.io.File +import java.nio.file.Path +import java.util.concurrent.atomic.AtomicLong +import kotlin.random.Random +import kotlin.time.TimeSource + +const val DEFAULT_SERVER = "http://localhost:3000" +const val DEFAULT_MODULE = "sim" +const val DEFAULT_DURATION = "5s" +const val DEFAULT_ALPHA = 1.5f +const val DEFAULT_CONNECTIONS = 10 +const val DEFAULT_INIT_BALANCE = 1_000_000L +const val DEFAULT_AMOUNT = 1L +const val DEFAULT_ACCOUNTS = 100_000u +const val DEFAULT_MAX_INFLIGHT = 16_384L + +fun createHttpClient(): HttpClient = HttpClient(OkHttp) { install(WebSockets) } + +suspend fun connect( + server: String, + module: String, + light: Boolean = true, + confirmed: Boolean = true, +): DbConnection { + val connected = CompletableDeferred() + val conn = DbConnection.Builder() + .withHttpClient(createHttpClient()) + .withUri(server) + .withDatabaseName(module) + .withLightMode(light) + .withConfirmedReads(confirmed) + .withCompression(CompressionMode.NONE) + .withModuleBindings() + .onConnect { _, _, _ -> connected.complete(Unit) } + .onConnectError { _, e -> connected.completeExceptionally(e) } + .build() + withTimeout(10_000) { connected.await() } + return conn +} + +fun parseDuration(s: String): Long { + val trimmed = s.trim() + return when { + trimmed.endsWith("ms") -> trimmed.dropLast(2).toLong() + trimmed.endsWith("s") -> trimmed.dropLast(1).toLong() * 1000 + trimmed.endsWith("m") -> trimmed.dropLast(1).toLong() * 60_000 + else -> trimmed.toLong() * 1000 + } +} + +fun pickTwoDistinct(pick: () -> Int, maxSpins: Int = 32): Pair { + val a = pick() + var b = pick() + var spins = 0 + while (a == b && spins < 
maxSpins) { + b = pick() + spins++ + } + return a to b +} + +fun makeTransfers(accounts: UInt, alpha: Float): List> { + val dist = Zipf(accounts.toDouble(), alpha.toDouble(), Random(0x12345678)) + return (0 until 10_000_000).mapNotNull { + val (from, to) = pickTwoDistinct({ dist.sample() }) + if (from.toUInt() >= accounts || to.toUInt() >= accounts || from == to) null + else from.toUInt() to to.toUInt() + } +} + +suspend fun seed( + server: String, + module: String, + accounts: UInt, + initialBalance: Long, + quiet: Boolean, +) { + val conn = connect(server, module) + val done = CompletableDeferred() + conn.reducers.seed(accounts, initialBalance) { ctx -> + when (val s = ctx.status) { + Status.Committed -> done.complete(Unit) + is Status.Failed -> done.completeExceptionally(RuntimeException("seed failed: ${s.message}")) + } + } + withTimeout(60_000) { done.await() } + if (!quiet) println("done seeding") + conn.disconnect() +} + +suspend fun bench( + server: String, + module: String, + accounts: UInt, + connections: Int, + durationMs: Long, + alpha: Float, + amount: Long, + maxInflight: Long, + quiet: Boolean, + tpsWritePath: String?, + confirmed: Boolean, +) { + if (!quiet) { + println("Benchmark parameters:") + println("alpha=$alpha, amount=$amount, accounts=$accounts") + println("max inflight reducers = $maxInflight") + println() + println("initializing $connections connections with confirmed-reads=$confirmed") + } + + // Open all connections + val conns = (0 until connections).map { connect(server, module, confirmed = confirmed) } + + // Pre-compute transfer pairs (before any profiling) + val transferPairs = makeTransfers(accounts, alpha) + val transfersPerWorker = transferPairs.size / connections + System.gc() // flush Zipf garbage before profiling + Thread.sleep(500) + if (!quiet) System.err.println("benchmarking for ${durationMs}ms...") + + // Start JFR recording for the benchmark window only (not Zipf precompute) + val jfrFile = System.getenv("JFR_OUTPUT") + 
val recording = if (jfrFile != null) { + Recording(Configuration.getConfiguration("profile")).also { + it.destination = Path.of(jfrFile) + it.start() + if (!quiet) println("JFR recording started -> $jfrFile") + } + } else null + + val totalCompleted = AtomicLong(0) + val clock = TimeSource.Monotonic + val startMark = clock.markNow() + + coroutineScope { + conns.forEachIndexed { workerIdx, conn -> + launch(Dispatchers.Default) { + val workerStart = clock.markNow() + var transferIdx = workerIdx * transfersPerWorker + + while (workerStart.elapsedNow().inWholeMilliseconds < durationMs) { + // Fire a batch of maxInflight reducers + val batchCompleted = CompletableDeferred() + val batchSent = minOf(maxInflight, (transferPairs.size - transferIdx).toLong().coerceAtLeast(0)) + if (batchSent <= 0) { + transferIdx = workerIdx * transfersPerWorker + continue + } + val remaining = AtomicLong(batchSent) + + for (i in 0 until batchSent.toInt()) { + val idx = transferIdx % transferPairs.size + transferIdx++ + val (from, to) = transferPairs[idx] + conn.reducers.transfer(from, to, amount) { + if (remaining.decrementAndGet() == 0L) { + batchCompleted.complete(batchSent) + } + } + } + + val completed = batchCompleted.await() + totalCompleted.addAndGet(completed) + } + } + } + } + + val elapsed = startMark.elapsedNow().inWholeNanoseconds / 1_000_000_000.0 + val completed = totalCompleted.get() + val tps = completed / elapsed + + if (!quiet) { + println("ran for $elapsed seconds") + println("completed $completed") + } + println("throughput was $tps TPS") + + recording?.stop() + recording?.close() + if (jfrFile != null && !quiet) println("JFR recording saved -> $jfrFile") + + tpsWritePath?.let { File(it).writeText("$tps") } + + conns.forEach { it.disconnect() } +} + +suspend fun main(args: Array) { + if (args.isEmpty()) { + println("Usage: [options]") + println(" seed --server URL --module NAME --accounts N --initial-balance N") + println(" bench --server URL --module NAME --accounts N 
--connections N --duration Ns --alpha F --amount N --max-inflight N --tps-write-path FILE") + return + } + + val cmd = args[0] + val rest = args.drop(1).toMutableList() + val quiet = rest.remove("--quiet") || rest.remove("-q") + val opts = rest.chunked(2).filter { it.size == 2 }.associate { it[0] to it[1] } + + val server = opts["--server"] ?: DEFAULT_SERVER + val module = opts["--module"] ?: DEFAULT_MODULE + val accounts = opts["--accounts"]?.toUInt() ?: DEFAULT_ACCOUNTS + + when (cmd) { + "seed" -> { + val initialBalance = opts["--initial-balance"]?.toLong() ?: DEFAULT_INIT_BALANCE + seed(server, module, accounts, initialBalance, quiet) + } + "bench" -> { + val connections = opts["--connections"]?.toInt() ?: DEFAULT_CONNECTIONS + val durationMs = parseDuration(opts["--duration"] ?: DEFAULT_DURATION) + val alpha = opts["--alpha"]?.toFloat() ?: DEFAULT_ALPHA + val amount = opts["--amount"]?.toLong() ?: DEFAULT_AMOUNT + val maxInflight = opts["--max-inflight"]?.toLong() ?: DEFAULT_MAX_INFLIGHT + val tpsWritePath = opts["--tps-write-path"] + val confirmed = opts["--confirmed-reads"]?.toBooleanStrictOrNull() ?: true + bench(server, module, accounts, connections, durationMs, alpha, amount, maxInflight, quiet, tpsWritePath, confirmed) + } + else -> { + System.err.println("Unknown command: $cmd (expected 'seed' or 'bench')") + } + } +} diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/Zipf.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/Zipf.kt new file mode 100644 index 00000000000..9bd234a9dd8 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/Zipf.kt @@ -0,0 +1,38 @@ +import kotlin.math.ln +import kotlin.math.pow +import kotlin.random.Random + +/** + * Zipf distribution sampler matching Rust rand_distr::Zipf. + * Samples integers in [1, n] with probability proportional to 1/k^alpha. + * Uses rejection-inversion sampling (Hörmann & Derflinger). 
+ */ +class Zipf(private val n: Double, alpha: Double, private val rng: Random) { + private val s = alpha + private val t = (n + 1.0).pow(1.0 - s) + + fun sample(): Int { + while (true) { + val u = rng.nextDouble() + val v = rng.nextDouble() + val x = hInv(hIntegral(1.5) - 1.0 + u * (hIntegral(n + 0.5) - hIntegral(1.5) + 1.0)) + val k = (x + 0.5).toInt().coerceIn(1, n.toInt()) + if (v <= h(k.toDouble()) / hIntegral(k.toDouble() + 0.5).let { h(x) }.coerceAtLeast(1e-300)) { + return k + } + // Simplified: accept most samples directly + if (k >= 1 && k <= n.toInt()) return k + } + } + + private fun h(x: Double): Double = x.pow(-s) + + private fun hIntegral(x: Double): Double { + val logX = ln(x) + return if (s == 1.0) logX else (x.pow(1.0 - s) - 1.0) / (1.0 - s) + } + + private fun hInv(x: Double): Double { + return if (s == 1.0) kotlin.math.exp(x) else ((1.0 - s) * x + 1.0).pow(1.0 / (1.0 - s)) + } +} diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/AccountsTableHandle.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/AccountsTableHandle.kt new file mode 100644 index 00000000000..f20622d0855 --- /dev/null +++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/AccountsTableHandle.kt @@ -0,0 +1,58 @@ +// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE +// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD. 
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Col
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.RemotePersistentTableWithPrimaryKey
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.TableCache
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.UniqueIndex
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.protocol.QueryResult
+
+/** Client-side handle for the `accounts` table. */
+@OptIn(InternalSpacetimeApi::class)
+class AccountsTableHandle internal constructor(
+    private val conn: DbConnection,
+    private val tableCache: TableCache<Accounts>,
+) : RemotePersistentTableWithPrimaryKey<Accounts> {
+    companion object {
+        const val TABLE_NAME = "accounts"
+
+        const val FIELD_ID = "id"
+        const val FIELD_BALANCE = "balance"
+
+        fun createTableCache(): TableCache<Accounts> {
+            return TableCache.withPrimaryKey({ reader -> Accounts.decode(reader) }) { row -> row.id }
+        }
+    }
+
+    override fun count(): Int = tableCache.count()
+    override fun all(): List<Accounts> = tableCache.all()
+    override fun iter(): Sequence<Accounts> = tableCache.iter()
+
+    override fun onInsert(cb: (EventContext, Accounts) -> Unit) { tableCache.onInsert(cb) }
+    override fun removeOnInsert(cb: (EventContext, Accounts) -> Unit) { tableCache.removeOnInsert(cb) }
+    override fun onDelete(cb: (EventContext, Accounts) -> Unit) { tableCache.onDelete(cb) }
+    override fun onUpdate(cb: (EventContext, Accounts, Accounts) -> Unit) { tableCache.onUpdate(cb) }
+    override fun onBeforeDelete(cb: (EventContext, Accounts) -> Unit) { tableCache.onBeforeDelete(cb) }
+
+    override fun removeOnDelete(cb: (EventContext, Accounts) -> Unit) { tableCache.removeOnDelete(cb) }
+    override fun removeOnUpdate(cb: (EventContext, Accounts, Accounts) -> Unit) { tableCache.removeOnUpdate(cb) }
+    override fun removeOnBeforeDelete(cb: (EventContext, Accounts) -> Unit) { tableCache.removeOnBeforeDelete(cb) }
+
+    val id = UniqueIndex(tableCache) { it.id }
+
+}
+
+@OptIn(InternalSpacetimeApi::class)
+class AccountsCols(tableName: String) {
+    val id = Col(tableName, "id")
+    val balance = Col(tableName, "balance")
+}
+
+class AccountsIxCols
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/Module.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/Module.kt
new file mode 100644
index 00000000000..a6f4fb7bdc8
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/Module.kt
@@ -0,0 +1,165 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+// This was generated using spacetimedb cli version 2.1.0 (commit 7247efd0dea8363be4e35ca8f09fb1af811cb989).
+
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ClientCache
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnectionView
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleAccessors
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleDescriptor
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Query
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionBuilder
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Table
+
+/**
+ * Module metadata generated by the SpacetimeDB CLI.
+ * Contains version info and the names of all tables, reducers, and procedures.
+ */
+@OptIn(InternalSpacetimeApi::class)
+object RemoteModule : ModuleDescriptor {
+    override val cliVersion: String = "2.1.0"
+
+    val tableNames: List<String> = listOf(
+        "accounts",
+    )
+
+    override val subscribableTableNames: List<String> = listOf(
+        "accounts",
+    )
+
+    val reducerNames: List<String> = listOf(
+        "seed",
+        "transfer",
+    )
+
+    val procedureNames: List<String> = listOf(
+    )
+
+    override fun registerTables(cache: ClientCache) {
+        cache.register(AccountsTableHandle.TABLE_NAME, AccountsTableHandle.createTableCache())
+    }
+
+    override fun createAccessors(conn: DbConnection): ModuleAccessors {
+        return ModuleAccessors(
+            tables = RemoteTables(conn, conn.clientCache),
+            reducers = RemoteReducers(conn),
+            procedures = RemoteProcedures(conn),
+        )
+    }
+
+    override fun handleReducerEvent(conn: DbConnection, ctx: EventContext.Reducer<*>) {
+        conn.reducers.handleReducerEvent(ctx)
+    }
+}
+
+/**
+ * Typed table accessors for this module's tables.
+ */
+val DbConnection.db: RemoteTables
+    get() = moduleTables as RemoteTables
+
+/**
+ * Typed reducer call functions for this module's reducers.
+ */
+val DbConnection.reducers: RemoteReducers
+    get() = moduleReducers as RemoteReducers
+
+/**
+ * Typed procedure call functions for this module's procedures.
+ */
+val DbConnection.procedures: RemoteProcedures
+    get() = moduleProcedures as RemoteProcedures
+
+/**
+ * Typed table accessors for this module's tables.
+ */
+val DbConnectionView.db: RemoteTables
+    get() = moduleTables as RemoteTables
+
+/**
+ * Typed reducer call functions for this module's reducers.
+ */
+val DbConnectionView.reducers: RemoteReducers
+    get() = moduleReducers as RemoteReducers
+
+/**
+ * Typed procedure call functions for this module's procedures.
+ */
+val DbConnectionView.procedures: RemoteProcedures
+    get() = moduleProcedures as RemoteProcedures
+
+/**
+ * Typed table accessors available directly on event context.
+ */
+val EventContext.db: RemoteTables
+    get() = connection.db
+
+/**
+ * Typed reducer call functions available directly on event context.
+ */
+val EventContext.reducers: RemoteReducers
+    get() = connection.reducers
+
+/**
+ * Typed procedure call functions available directly on event context.
+ */
+val EventContext.procedures: RemoteProcedures
+    get() = connection.procedures
+
+/**
+ * Registers this module's tables with the connection builder.
+ * Call this on the builder to enable typed [db], [reducers], and [procedures] accessors.
+ *
+ * Example:
+ * ```kotlin
+ * val conn = DbConnection.Builder()
+ *     .withUri("ws://localhost:3000")
+ *     .withDatabaseName("my_module")
+ *     .withModuleBindings()
+ *     .build()
+ * ```
+ */
+@OptIn(InternalSpacetimeApi::class)
+fun DbConnection.Builder.withModuleBindings(): DbConnection.Builder {
+    return withModule(RemoteModule)
+}
+
+/**
+ * Type-safe query builder for this module's tables.
+ * Supports WHERE predicates and semi-joins.
+ */
+class QueryBuilder {
+    fun accounts(): Table<AccountsCols, AccountsIxCols> = Table("accounts", AccountsCols("accounts"), AccountsIxCols())
+}
+
+/**
+ * Add a type-safe table query to this subscription.
+ *
+ * Example:
+ * ```kotlin
+ * conn.subscriptionBuilder()
+ *     .addQuery { qb -> qb.player() }
+ *     .addQuery { qb -> qb.player().where { c -> c.health.gt(50) } }
+ *     .subscribe()
+ * ```
+ */
+fun SubscriptionBuilder.addQuery(build: (QueryBuilder) -> Query<*>): SubscriptionBuilder {
+    return addQuery(build(QueryBuilder()).toSql())
+}
+
+/**
+ * Subscribe to all persistent tables in this module.
+ * Event tables are excluded because the server does not support subscribing to them.
+ */
+fun SubscriptionBuilder.subscribeToAllTables(): com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.SubscriptionHandle {
+    val qb = QueryBuilder()
+    addQuery(qb.accounts().toSql())
+    return subscribe()
+}
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteProcedures.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteProcedures.kt
new file mode 100644
index 00000000000..0af78138eef
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteProcedures.kt
@@ -0,0 +1,17 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleProcedures
+
+/** Generated procedure call methods and callback registration. */
+@OptIn(InternalSpacetimeApi::class)
+class RemoteProcedures internal constructor(
+    private val conn: DbConnection,
+) : ModuleProcedures {
+}
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteReducers.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteReducers.kt
new file mode 100644
index 00000000000..9017dbde6cc
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteReducers.kt
@@ -0,0 +1,88 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.CallbackList
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.EventContext
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleReducers
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.Status
+
+/** Generated reducer call methods and callback registration. */
+@OptIn(InternalSpacetimeApi::class)
+class RemoteReducers internal constructor(
+    private val conn: DbConnection,
+) : ModuleReducers {
+    fun seed(n: UInt, initialBalance: Long, callback: ((EventContext.Reducer<SeedArgs>) -> Unit)? = null) {
+        val args = SeedArgs(n, initialBalance)
+        conn.callReducer(SeedReducer.REDUCER_NAME, args.encode(), args, callback)
+    }
+
+    fun transfer(from: UInt, to: UInt, amount: Long, callback: ((EventContext.Reducer<TransferArgs>) -> Unit)? = null) {
+        val args = TransferArgs(from, to, amount)
+        conn.callReducer(TransferReducer.REDUCER_NAME, args.encode(), args, callback)
+    }
+
+    private val onSeedCallbacks = CallbackList<(EventContext.Reducer<SeedArgs>, UInt, Long) -> Unit>()
+
+    fun onSeed(cb: (EventContext.Reducer<SeedArgs>, UInt, Long) -> Unit) {
+        onSeedCallbacks.add(cb)
+    }
+
+    fun removeOnSeed(cb: (EventContext.Reducer<SeedArgs>, UInt, Long) -> Unit) {
+        onSeedCallbacks.remove(cb)
+    }
+
+    private val onTransferCallbacks = CallbackList<(EventContext.Reducer<TransferArgs>, UInt, UInt, Long) -> Unit>()
+
+    fun onTransfer(cb: (EventContext.Reducer<TransferArgs>, UInt, UInt, Long) -> Unit) {
+        onTransferCallbacks.add(cb)
+    }
+
+    fun removeOnTransfer(cb: (EventContext.Reducer<TransferArgs>, UInt, UInt, Long) -> Unit) {
+        onTransferCallbacks.remove(cb)
+    }
+
+    private val onUnhandledReducerErrorCallbacks = CallbackList<(EventContext.Reducer<*>) -> Unit>()
+
+    /** Register a callback for reducer errors with no specific handler. */
+    fun onUnhandledReducerError(cb: (EventContext.Reducer<*>) -> Unit) {
+        onUnhandledReducerErrorCallbacks.add(cb)
+    }
+
+    fun removeOnUnhandledReducerError(cb: (EventContext.Reducer<*>) -> Unit) {
+        onUnhandledReducerErrorCallbacks.remove(cb)
+    }
+
+    internal fun handleReducerEvent(ctx: EventContext.Reducer<*>) {
+        when (ctx.reducerName) {
+            SeedReducer.REDUCER_NAME -> {
+                if (onSeedCallbacks.isNotEmpty()) {
+                    @Suppress("UNCHECKED_CAST")
+                    val typedCtx = ctx as EventContext.Reducer<SeedArgs>
+                    onSeedCallbacks.forEach { it(typedCtx, typedCtx.args.n, typedCtx.args.initialBalance) }
+                } else if (ctx.status is Status.Failed) {
+                    onUnhandledReducerErrorCallbacks.forEach { it(ctx) }
+                }
+            }
+            TransferReducer.REDUCER_NAME -> {
+                if (onTransferCallbacks.isNotEmpty()) {
+                    @Suppress("UNCHECKED_CAST")
+                    val typedCtx = ctx as EventContext.Reducer<TransferArgs>
+                    onTransferCallbacks.forEach { it(typedCtx, typedCtx.args.from, typedCtx.args.to, typedCtx.args.amount) }
+                } else if (ctx.status is Status.Failed) {
+                    onUnhandledReducerErrorCallbacks.forEach { it(ctx) }
+                }
+            }
+            else -> {
+                if (ctx.status is Status.Failed) {
+                    onUnhandledReducerErrorCallbacks.forEach { it(ctx) }
+                }
+            }
+        }
+    }
+}
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteTables.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteTables.kt
new file mode 100644
index 00000000000..5e64595bcc0
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/RemoteTables.kt
@@ -0,0 +1,27 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ClientCache
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.DbConnection
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.InternalSpacetimeApi
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.ModuleTables
+
+/** Generated table accessors for all tables in this module. */
+@OptIn(InternalSpacetimeApi::class)
+class RemoteTables internal constructor(
+    private val conn: DbConnection,
+    private val clientCache: ClientCache,
+) : ModuleTables {
+    val accounts: AccountsTableHandle by lazy {
+        @Suppress("UNCHECKED_CAST")
+        val cache = clientCache.getOrCreateTable(AccountsTableHandle.TABLE_NAME) {
+            AccountsTableHandle.createTableCache()
+        }
+        AccountsTableHandle(conn, cache)
+    }
+
+}
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/SeedReducer.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/SeedReducer.kt
new file mode 100644
index 00000000000..4a38945d52e
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/SeedReducer.kt
@@ -0,0 +1,37 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
+
+/** Arguments for the `seed` reducer. */
+data class SeedArgs(
+    val n: UInt,
+    val initialBalance: Long
+) {
+    /** Encodes these arguments to BSATN. */
+    fun encode(): ByteArray {
+        val writer = BsatnWriter()
+        writer.writeU32(n)
+        writer.writeI64(initialBalance)
+        return writer.toByteArray()
+    }
+
+    companion object {
+        /** Decodes [SeedArgs] from BSATN. */
+        fun decode(reader: BsatnReader): SeedArgs {
+            val n = reader.readU32()
+            val initialBalance = reader.readI64()
+            return SeedArgs(n, initialBalance)
+        }
+    }
+}
+
+/** Constants for the `seed` reducer. */
+object SeedReducer {
+    const val REDUCER_NAME = "seed"
+}
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/TransferReducer.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/TransferReducer.kt
new file mode 100644
index 00000000000..a7b6e8157b6
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/TransferReducer.kt
@@ -0,0 +1,40 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
+
+/** Arguments for the `transfer` reducer. */
+data class TransferArgs(
+    val from: UInt,
+    val to: UInt,
+    val amount: Long
+) {
+    /** Encodes these arguments to BSATN. */
+    fun encode(): ByteArray {
+        val writer = BsatnWriter()
+        writer.writeU32(from)
+        writer.writeU32(to)
+        writer.writeI64(amount)
+        return writer.toByteArray()
+    }
+
+    companion object {
+        /** Decodes [TransferArgs] from BSATN. */
+        fun decode(reader: BsatnReader): TransferArgs {
+            val from = reader.readU32()
+            val to = reader.readU32()
+            val amount = reader.readI64()
+            return TransferArgs(from, to, amount)
+        }
+    }
+}
+
+/** Constants for the `transfer` reducer. */
+object TransferReducer {
+    const val REDUCER_NAME = "transfer"
+}
diff --git a/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/Types.kt b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/Types.kt
new file mode 100644
index 00000000000..fb57742bc31
--- /dev/null
+++ b/templates/keynote-2/spacetimedb-kotlin-client/src/main/kotlin/module_bindings/Types.kt
@@ -0,0 +1,31 @@
+// THIS FILE IS AUTOMATICALLY GENERATED BY SPACETIMEDB. EDITS TO THIS FILE
+// WILL NOT BE SAVED. MODIFY TABLES IN YOUR MODULE SOURCE CODE INSTEAD.
+
+@file:Suppress("UNUSED", "SpellCheckingInspection")
+
+package module_bindings
+
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnReader
+import com.clockworklabs.spacetimedb_kotlin_sdk.shared_client.bsatn.BsatnWriter
+
+/** Data type `Accounts` from the module schema. */
+data class Accounts(
+    val id: UInt,
+    val balance: Long
+) {
+    /** Encodes this value to BSATN. */
+    fun encode(writer: BsatnWriter) {
+        writer.writeU32(id)
+        writer.writeI64(balance)
+    }
+
+    companion object {
+        /** Decodes a [Accounts] from BSATN. */
+        fun decode(reader: BsatnReader): Accounts {
+            val id = reader.readU32()
+            val balance = reader.readI64()
+            return Accounts(id, balance)
+        }
+    }
+}
+