diff --git a/integration/copy_fk_lookup.sql b/integration/copy_fk_lookup.sql
new file mode 100644
index 000000000..9ee486422
--- /dev/null
+++ b/integration/copy_fk_lookup.sql
@@ -0,0 +1,52 @@
+-- Test COPY with FK lookup for sharding
+-- Run with: psql -h 127.0.0.1 -p 6432 -U pgdog -d pgdog_sharded -f copy_fk_lookup.sql
+
+\echo '=== Setup: Creating FK tables ==='
+
+DROP TABLE IF EXISTS public.copy_orders, public.copy_users;
+
+CREATE TABLE public.copy_users (
+    id BIGINT PRIMARY KEY,
+    customer_id BIGINT NOT NULL
+);
+
+CREATE TABLE public.copy_orders (
+    id BIGINT PRIMARY KEY,
+    user_id BIGINT REFERENCES public.copy_users(id)
+);
+
+\echo '=== Inserting users with sharding key ==='
+
+INSERT INTO public.copy_users (id, customer_id)
+SELECT i, i * 100 + (i % 17)
+FROM generate_series(1, 100) AS i;
+
+\echo '=== COPY orders via FK lookup (text format) ==='
+
+COPY public.copy_orders (id, user_id) FROM STDIN;
+10	1
+20	2
+30	3
+40	4
+50	5
+60	6
+70	7
+80	8
+90	9
+100	10
+\.
+
+\echo '=== Verify: Count orders ==='
+SELECT COUNT(*) AS order_count FROM public.copy_orders;
+
+\echo '=== Verify: Sample data ==='
+SELECT o.id AS order_id, o.user_id, u.customer_id
+FROM public.copy_orders o
+JOIN public.copy_users u ON o.user_id = u.id
+ORDER BY o.id
+LIMIT 10;
+
+\echo '=== Cleanup ==='
+DROP TABLE IF EXISTS public.copy_orders, public.copy_users;
+
+\echo '=== DONE ==='
diff --git a/integration/pgdog.toml b/integration/pgdog.toml
index 041d6bb2e..c2644b160 100644
--- a/integration/pgdog.toml
+++ b/integration/pgdog.toml
@@ -22,7 +22,7 @@ tls_certificate = "integration/tls/cert.pem"
 tls_private_key = "integration/tls/key.pem"
 query_parser_engine = "pg_query_raw"
 system_catalogs = "omnisharded_sticky"
-reload_schema_on_ddl = false
+reload_schema_on_ddl = true
 
 [memory]
 net_buffer = 8096
diff --git a/integration/rust/tests/integration/copy.rs b/integration/rust/tests/integration/copy.rs
new file mode 100644
index 000000000..976c40d04
--- /dev/null
+++ b/integration/rust/tests/integration/copy.rs
@@ -0,0 +1,235 @@
+//! Integration tests for COPY with FK lookup sharding.
+//!
+//! Tests that pgdog can resolve sharding keys via FK relationships during COPY.
+//! The child table (copy_orders) does NOT have the sharding key (customer_id) directly;
+//! pgdog must look it up via the FK to copy_users.
+
+use rust::setup::{admin_sqlx, connections_sqlx};
+use sqlx::postgres::PgPoolCopyExt;
+use sqlx::{Executor, Pool, Postgres};
+
+async fn setup_fk_tables(pool: &Pool<Postgres>, admin: &Pool<Postgres>) {
+    // Drop tables if they exist (in public schema)
+    pool.execute("DROP TABLE IF EXISTS public.copy_orders, public.copy_users")
+        .await
+        .unwrap();
+
+    // Create users table with sharding key (customer_id) in public schema
+    pool.execute(
+        "CREATE TABLE public.copy_users (
+            id BIGINT PRIMARY KEY,
+            customer_id BIGINT NOT NULL
+        )",
+    )
+    .await
+    .unwrap();
+
+    // Create orders table with FK to users - no customer_id column!
+    // pgdog must look up customer_id via the FK
+    pool.execute(
+        "CREATE TABLE public.copy_orders (
+            id BIGINT PRIMARY KEY,
+            user_id BIGINT REFERENCES public.copy_users(id)
+        )",
+    )
+    .await
+    .unwrap();
+
+    // Reload schema so pgdog picks up the new tables and FK relationships
+    admin.execute("RELOAD").await.unwrap();
+
+    // Wait for schema reload to complete (happens asynchronously)
+    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
+}
+
+async fn cleanup_fk_tables(pool: &Pool<Postgres>, admin: &Pool<Postgres>) {
+    pool.execute("DROP TABLE IF EXISTS public.copy_orders, public.copy_users")
+        .await
+        .ok();
+    admin.execute("RELOAD").await.unwrap();
+}
+
+#[tokio::test]
+async fn test_copy_fk_lookup_text() {
+    let admin = admin_sqlx().await;
+    let mut pools = connections_sqlx().await;
+    let pool = pools.swap_remove(1); // pgdog_sharded
+
+    setup_fk_tables(&pool, &admin).await;
+
+    // Insert users with varying customer_id (sharding key)
+    for i in 1i64..=100 {
+        sqlx::query("INSERT INTO public.copy_users (id, customer_id) VALUES ($1, $2)")
+            .bind(i)
+            .bind(i * 100 + i % 17)
+            .execute(&pool)
+            .await
+            .unwrap();
+    }
+
+    // Use COPY to insert orders referencing users
+    // Only pass id and user_id - pgdog should look up customer_id via FK
+    let copy_data: String = (1i64..=100)
+        .map(|i| format!("{}\t{}\n", i * 10, i)) // order_id, user_id (FK)
+        .collect();
+
+    let mut copy_in = pool
+        .copy_in_raw("COPY public.copy_orders (id, user_id) FROM STDIN")
+        .await
+        .unwrap();
+    copy_in.send(copy_data.as_bytes()).await.unwrap();
+    copy_in.finish().await.unwrap();
+
+    // Verify all orders were inserted
+    let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM public.copy_orders")
+        .fetch_one(&pool)
+        .await
+        .unwrap();
+    assert_eq!(count.0, 100, "expected 100 orders");
+
+    // Verify orders data is correct (no JOINs - pgdog doesn't support sharded JOINs)
+    let orders: Vec<(i64, i64)> =
+        sqlx::query_as("SELECT id, user_id FROM public.copy_orders ORDER BY id LIMIT 10")
+            .fetch_all(&pool)
+            .await
+            .unwrap();
+
+    assert_eq!(orders.len(), 10);
+    for (i, (order_id, user_id)) in orders.into_iter().enumerate() {
+        let expected_order_id = ((i + 1) * 10) as i64;
+        let expected_user_id = (i + 1) as i64;
+        assert_eq!(order_id, expected_order_id);
+        assert_eq!(user_id, expected_user_id);
+    }
+
+    cleanup_fk_tables(&pool, &admin).await;
+}
+
+#[tokio::test]
+async fn test_copy_fk_lookup_binary() {
+    let admin = admin_sqlx().await;
+    let mut pools = connections_sqlx().await;
+    let pool = pools.swap_remove(1); // pgdog_sharded
+
+    setup_fk_tables(&pool, &admin).await;
+
+    // Insert users with varying customer_id (sharding key)
+    for i in 1i64..=50 {
+        sqlx::query("INSERT INTO public.copy_users (id, customer_id) VALUES ($1, $2)")
+            .bind(i)
+            .bind(i * 100 + i % 13)
+            .execute(&pool)
+            .await
+            .unwrap();
+    }
+
+    // Use binary COPY to insert orders
+    // Only pass id and user_id - pgdog should look up customer_id via FK
+    let mut binary_data = Vec::new();
+
+    // Binary COPY header: PGCOPY\n\377\r\n\0
+    binary_data.extend_from_slice(b"PGCOPY\n\xff\r\n\0");
+    // Flags (4 bytes) + header extension length (4 bytes)
+    binary_data.extend_from_slice(&[0u8; 8]);
+
+    // Insert 50 rows
+    for i in 1i64..=50 {
+        // Number of columns (2 bytes)
+        binary_data.extend_from_slice(&2i16.to_be_bytes());
+        // Column 1: order id (8 bytes for BIGINT)
+        binary_data.extend_from_slice(&8i32.to_be_bytes()); // length
+        binary_data.extend_from_slice(&(i * 10).to_be_bytes()); // value
+        // Column 2: user_id FK (8 bytes for BIGINT)
+        binary_data.extend_from_slice(&8i32.to_be_bytes()); // length
+        binary_data.extend_from_slice(&i.to_be_bytes()); // value
+    }
+
+    // Trailer: -1 (2 bytes)
+    binary_data.extend_from_slice(&(-1i16).to_be_bytes());
+
+    let mut copy_in = pool
+        .copy_in_raw("COPY public.copy_orders (id, user_id) FROM STDIN WITH (FORMAT binary)")
+        .await
+        .unwrap();
+    copy_in.send(binary_data.as_slice()).await.unwrap();
+    let result = copy_in.finish().await;
+    if let Err(e) = &result {
+        eprintln!("Binary COPY failed: {:?}", e);
+    }
+    let rows_copied = result.unwrap();
+    assert_eq!(rows_copied, 50, "expected 50 rows copied");
+
+    // Small delay before verify
+    tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
+
+    // Verify all orders were inserted
+    let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM public.copy_orders")
+        .fetch_one(&pool)
+        .await
+        .unwrap();
+    assert_eq!(count.0, 50, "expected 50 orders");
+
+    // Verify data integrity
+    let rows: Vec<(i64, i64)> =
+        sqlx::query_as("SELECT id, user_id FROM public.copy_orders ORDER BY id")
+            .fetch_all(&pool)
+            .await
+            .unwrap();
+
+    for (i, (order_id, user_id)) in rows.into_iter().enumerate() {
+        let expected_order_id = ((i + 1) * 10) as i64;
+        let expected_user_id = (i + 1) as i64;
+        assert_eq!(order_id, expected_order_id);
+        assert_eq!(user_id, expected_user_id);
+    }
+
+    cleanup_fk_tables(&pool, &admin).await;
+}
+
+#[tokio::test]
+async fn test_copy_direct_sharding_key() {
+    let admin = admin_sqlx().await;
+    let mut pools = connections_sqlx().await;
+    let pool = pools.swap_remove(1); // pgdog_sharded
+
+    // For this test, we COPY directly to copy_users which has the sharding key
+    pool.execute("DROP TABLE IF EXISTS public.copy_orders, public.copy_users")
+        .await
+        .unwrap();
+
+    pool.execute(
+        "CREATE TABLE public.copy_users (
+            id BIGINT PRIMARY KEY,
+            customer_id BIGINT NOT NULL
+        )",
+    )
+    .await
+    .unwrap();
+
+    admin.execute("RELOAD").await.unwrap();
+    tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
+
+    // Use COPY to insert users directly (sharding by customer_id column)
+    let copy_data: String = (1i64..=100)
+        .map(|i| format!("{}\t{}\n", i, i * 100 + i % 7))
+        .collect();
+
+    let mut copy_in = pool
+        .copy_in_raw("COPY public.copy_users (id, customer_id) FROM STDIN")
+        .await
+        .unwrap();
+    copy_in.send(copy_data.as_bytes()).await.unwrap();
+    copy_in.finish().await.unwrap();
+
+    // Verify all users were inserted
+    let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM public.copy_users")
+        .fetch_one(&pool)
+        .await
+        .unwrap();
+    assert_eq!(count.0, 100, "expected 100 users");
+
+    pool.execute("DROP TABLE IF EXISTS public.copy_users")
+        .await
+        .unwrap();
+    admin.execute("RELOAD").await.unwrap();
+}
diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs
index a25d94950..8053dfaa0 100644
--- a/integration/rust/tests/integration/mod.rs
+++ b/integration/rust/tests/integration/mod.rs
@@ -5,6 +5,7 @@ pub mod avg;
 pub mod ban;
 pub mod client_ids;
 pub mod connection_recovery;
+pub mod copy;
 pub mod cross_shard_disabled;
 pub mod distinct;
 pub mod explain;
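Review note on test_copy_fk_lookup_binary above: the hand-rolled bytes follow PostgreSQL's binary COPY framing (11-byte signature, 4-byte flags, 4-byte header-extension length, then per tuple a 2-byte column count and length-prefixed values, closed by a -1 trailer). That framing can be factored into a small helper; this is an illustrative sketch under that assumption, not code from the PR (the function name is hypothetical):

    /// Frame (i64, i64) rows for `COPY ... FROM STDIN WITH (FORMAT binary)`.
    fn frame_binary_copy(rows: &[(i64, i64)]) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(b"PGCOPY\n\xff\r\n\0"); // 11-byte signature
        buf.extend_from_slice(&[0u8; 8]); // flags (4 bytes) + header extension length (4 bytes)
        for (a, b) in rows {
            buf.extend_from_slice(&2i16.to_be_bytes()); // column count per tuple
            for v in [a, b] {
                buf.extend_from_slice(&8i32.to_be_bytes()); // field length
                buf.extend_from_slice(&v.to_be_bytes()); // BIGINT value, big-endian
            }
        }
        buf.extend_from_slice(&(-1i16).to_be_bytes()); // end-of-data trailer
        buf
    }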
diff --git a/pgdog-plugin/src/bindings.rs b/pgdog-plugin/src/bindings.rs
index 6f47703df..561d24e5b 100644
--- a/pgdog-plugin/src/bindings.rs
+++ b/pgdog-plugin/src/bindings.rs
@@ -1,338 +1,213 @@
 /* automatically generated by rust-bindgen 0.71.1 */
 
+pub const _STDINT_H: u32 = 1;
+pub const _FEATURES_H: u32 = 1;
+pub const _DEFAULT_SOURCE: u32 = 1;
+pub const __GLIBC_USE_ISOC2Y: u32 = 0;
+pub const __GLIBC_USE_ISOC23: u32 = 0;
+pub const __USE_ISOC11: u32 = 1;
+pub const __USE_ISOC99: u32 = 1;
+pub const __USE_ISOC95: u32 = 1;
+pub const __USE_POSIX_IMPLICITLY: u32 = 1;
+pub const _POSIX_SOURCE: u32 = 1;
+pub const _POSIX_C_SOURCE: u32 = 200809;
+pub const __USE_POSIX: u32 = 1;
+pub const __USE_POSIX2: u32 = 1;
+pub const __USE_POSIX199309: u32 = 1;
+pub const __USE_POSIX199506: u32 = 1;
+pub const __USE_XOPEN2K: u32 = 1;
+pub const __USE_XOPEN2K8: u32 = 1;
+pub const _ATFILE_SOURCE: u32 = 1;
 pub const __WORDSIZE: u32 = 64;
-pub const __has_safe_buffers: u32 = 1;
-pub const __DARWIN_ONLY_64_BIT_INO_T: u32 = 1;
-pub const __DARWIN_ONLY_UNIX_CONFORMANCE: u32 = 1;
-pub const __DARWIN_ONLY_VERS_1050: u32 = 1;
-pub const __DARWIN_UNIX03: u32 = 1;
-pub const __DARWIN_64_BIT_INO_T: u32 = 1;
-pub const __DARWIN_VERS_1050: u32 = 1;
-pub const __DARWIN_NON_CANCELABLE: u32 = 0;
-pub const __DARWIN_SUF_EXTSN: &[u8; 14] = b"$DARWIN_EXTSN\0";
-pub const __DARWIN_C_ANSI: u32 = 4096;
-pub const __DARWIN_C_FULL: u32 = 900000;
-pub const __DARWIN_C_LEVEL: u32 = 900000;
-pub const __STDC_WANT_LIB_EXT1__: u32 = 1;
-pub const __DARWIN_NO_LONG_LONG: u32 = 0;
-pub const _DARWIN_FEATURE_64_BIT_INODE: u32 = 1;
-pub const _DARWIN_FEATURE_ONLY_64_BIT_INODE: u32 = 1;
-pub const _DARWIN_FEATURE_ONLY_VERS_1050: u32 = 1;
-pub const _DARWIN_FEATURE_ONLY_UNIX_CONFORMANCE: u32 = 1;
-pub const _DARWIN_FEATURE_UNIX_CONFORMANCE: u32 = 3;
-pub const __has_ptrcheck: u32 = 0;
-pub const USE_CLANG_TYPES: u32 = 0;
-pub const __PTHREAD_SIZE__: u32 = 8176;
-pub const __PTHREAD_ATTR_SIZE__: u32 = 56;
-pub const __PTHREAD_MUTEXATTR_SIZE__: u32 = 8;
-pub const __PTHREAD_MUTEX_SIZE__: u32 = 56;
-pub const __PTHREAD_CONDATTR_SIZE__: u32 = 8;
-pub const __PTHREAD_COND_SIZE__: u32 = 40;
-pub const __PTHREAD_ONCE_SIZE__: u32 = 8;
-pub const __PTHREAD_RWLOCK_SIZE__: u32 = 192;
-pub const __PTHREAD_RWLOCKATTR_SIZE__: u32 = 16;
-pub const INT8_MAX: u32 = 127;
-pub const INT16_MAX: u32 = 32767;
-pub const INT32_MAX: u32 = 2147483647;
-pub const INT64_MAX: u64 = 9223372036854775807;
+pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1;
+pub const __SYSCALL_WORDSIZE: u32 = 64;
+pub const __TIMESIZE: u32 = 64;
+pub const __USE_TIME_BITS64: u32 = 1;
+pub const __USE_MISC: u32 = 1;
+pub const __USE_ATFILE: u32 = 1;
+pub const __USE_FORTIFY_LEVEL: u32 = 0;
+pub const __GLIBC_USE_DEPRECATED_GETS: u32 = 0;
+pub const __GLIBC_USE_DEPRECATED_SCANF: u32 = 0;
+pub const __GLIBC_USE_C23_STRTOL: u32 = 0;
+pub const _STDC_PREDEF_H: u32 = 1;
+pub const __STDC_IEC_559__: u32 = 1;
+pub const __STDC_IEC_60559_BFP__: u32 = 201404;
+pub const __STDC_IEC_559_COMPLEX__: u32 = 1;
+pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404;
+pub const __STDC_ISO_10646__: u32 = 201706;
+pub const __GNU_LIBRARY__: u32 = 6;
+pub const __GLIBC__: u32 = 2;
+pub const __GLIBC_MINOR__: u32 = 42;
+pub const _SYS_CDEFS_H: u32 = 1;
+pub const __glibc_c99_flexarr_available: u32 = 1;
+pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0;
+pub const __HAVE_GENERIC_SELECTION: u32 = 1;
+pub const __GLIBC_USE_LIB_EXT2: u32 = 0;
+pub const __GLIBC_USE_IEC_60559_BFP_EXT: u32 = 0;
+pub const __GLIBC_USE_IEC_60559_BFP_EXT_C23: u32 = 0;
+pub const __GLIBC_USE_IEC_60559_EXT: u32 = 0;
+pub const __GLIBC_USE_IEC_60559_FUNCS_EXT: u32 = 0;
+pub const __GLIBC_USE_IEC_60559_FUNCS_EXT_C23: u32 = 0;
+pub const __GLIBC_USE_IEC_60559_TYPES_EXT: u32 = 0;
+pub const _BITS_TYPES_H: u32 = 1;
+pub const _BITS_TYPESIZES_H: u32 = 1;
+pub const __OFF_T_MATCHES_OFF64_T: u32 = 1;
+pub const __INO_T_MATCHES_INO64_T: u32 = 1;
+pub const __RLIM_T_MATCHES_RLIM64_T: u32 = 1;
+pub const __STATFS_MATCHES_STATFS64: u32 = 1;
+pub const __KERNEL_OLD_TIMEVAL_MATCHES_TIMEVAL64: u32 = 1;
+pub const __FD_SETSIZE: u32 = 1024;
+pub const _BITS_TIME64_H: u32 = 1;
+pub const _BITS_WCHAR_H: u32 = 1;
+pub const _BITS_STDINT_INTN_H: u32 = 1;
+pub const _BITS_STDINT_UINTN_H: u32 = 1;
+pub const _BITS_STDINT_LEAST_H: u32 = 1;
 pub const INT8_MIN: i32 = -128;
 pub const INT16_MIN: i32 = -32768;
 pub const INT32_MIN: i32 = -2147483648;
-pub const INT64_MIN: i64 = -9223372036854775808;
+pub const INT8_MAX: u32 = 127;
+pub const INT16_MAX: u32 = 32767;
+pub const INT32_MAX: u32 = 2147483647;
 pub const UINT8_MAX: u32 = 255;
 pub const UINT16_MAX: u32 = 65535;
 pub const UINT32_MAX: u32 = 4294967295;
-pub const UINT64_MAX: i32 = -1;
 pub const INT_LEAST8_MIN: i32 = -128;
 pub const INT_LEAST16_MIN: i32 = -32768;
 pub const INT_LEAST32_MIN: i32 = -2147483648;
-pub const INT_LEAST64_MIN: i64 = -9223372036854775808;
 pub const INT_LEAST8_MAX: u32 = 127;
 pub const INT_LEAST16_MAX: u32 = 32767;
 pub const INT_LEAST32_MAX: u32 = 2147483647;
-pub const INT_LEAST64_MAX: u64 = 9223372036854775807;
 pub const UINT_LEAST8_MAX: u32 = 255;
 pub const UINT_LEAST16_MAX: u32 = 65535;
 pub const UINT_LEAST32_MAX: u32 = 4294967295;
-pub const UINT_LEAST64_MAX: i32 = -1;
 pub const INT_FAST8_MIN: i32 = -128;
-pub const INT_FAST16_MIN: i32 = -32768;
-pub const INT_FAST32_MIN: i32 = -2147483648;
-pub const INT_FAST64_MIN: i64 = -9223372036854775808;
+pub const INT_FAST16_MIN: i64 = -9223372036854775808;
+pub const INT_FAST32_MIN: i64 = -9223372036854775808;
 pub const INT_FAST8_MAX: u32 = 127;
-pub const INT_FAST16_MAX: u32 = 32767;
-pub const INT_FAST32_MAX: u32 = 2147483647;
-pub const INT_FAST64_MAX: u64 = 9223372036854775807;
+pub const INT_FAST16_MAX: u64 = 9223372036854775807;
+pub const INT_FAST32_MAX: u64 = 9223372036854775807;
 pub const UINT_FAST8_MAX: u32 = 255;
-pub const UINT_FAST16_MAX: u32 = 65535;
-pub const UINT_FAST32_MAX: u32 = 4294967295;
-pub const UINT_FAST64_MAX: i32 = -1;
-pub const INTPTR_MAX: u64 = 9223372036854775807;
+pub const UINT_FAST16_MAX: i32 = -1;
+pub const UINT_FAST32_MAX: i32 = -1;
 pub const INTPTR_MIN: i64 = -9223372036854775808;
+pub const INTPTR_MAX: u64 = 9223372036854775807;
 pub const UINTPTR_MAX: i32 = -1;
-pub const SIZE_MAX: i32 = -1;
-pub const RSIZE_MAX: i32 = -1;
-pub const WINT_MIN: i32 = -2147483648;
-pub const WINT_MAX: u32 = 2147483647;
+pub const PTRDIFF_MIN: i64 = -9223372036854775808;
+pub const PTRDIFF_MAX: u64 = 9223372036854775807;
 pub const SIG_ATOMIC_MIN: i32 = -2147483648;
 pub const SIG_ATOMIC_MAX: u32 = 2147483647;
+pub const SIZE_MAX: i32 = -1;
+pub const WINT_MIN: u32 = 0;
+pub const WINT_MAX: u32 = 4294967295;
 pub type wchar_t = ::std::os::raw::c_int;
-pub type max_align_t = f64;
-pub type int_least8_t = i8;
-pub type int_least16_t = i16;
-pub type int_least32_t = i32;
-pub type int_least64_t = i64;
-pub type uint_least8_t = u8;
-pub type uint_least16_t = u16;
-pub type uint_least32_t = u32;
-pub type uint_least64_t = u64;
-pub type int_fast8_t = i8;
-pub type int_fast16_t = i16;
-pub type int_fast32_t = i32;
-pub type int_fast64_t = i64;
-pub type uint_fast8_t = u8;
-pub type uint_fast16_t = u16;
-pub type uint_fast32_t = u32;
-pub type uint_fast64_t = u64;
-pub type __int8_t = ::std::os::raw::c_schar;
-pub type __uint8_t = ::std::os::raw::c_uchar;
-pub type __int16_t = ::std::os::raw::c_short;
-pub type __uint16_t = ::std::os::raw::c_ushort;
-pub type __int32_t = ::std::os::raw::c_int;
-pub type __uint32_t = ::std::os::raw::c_uint;
-pub type __int64_t = ::std::os::raw::c_longlong;
-pub type __uint64_t = ::std::os::raw::c_ulonglong;
-pub type __darwin_intptr_t = ::std::os::raw::c_long;
-pub type __darwin_natural_t = ::std::os::raw::c_uint;
-pub type __darwin_ct_rune_t = ::std::os::raw::c_int;
-#[repr(C)]
-#[derive(Copy, Clone)]
-pub union __mbstate_t {
-    pub __mbstate8: [::std::os::raw::c_char; 128usize],
-    pub _mbstateL: ::std::os::raw::c_longlong,
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of __mbstate_t"][::std::mem::size_of::<__mbstate_t>() - 128usize];
-    ["Alignment of __mbstate_t"][::std::mem::align_of::<__mbstate_t>() - 8usize];
-    ["Offset of field: __mbstate_t::__mbstate8"]
-        [::std::mem::offset_of!(__mbstate_t, __mbstate8) - 0usize];
-    ["Offset of field: __mbstate_t::_mbstateL"]
-        [::std::mem::offset_of!(__mbstate_t, _mbstateL) - 0usize];
-};
-pub type __darwin_mbstate_t = __mbstate_t;
-pub type __darwin_ptrdiff_t = ::std::os::raw::c_long;
-pub type __darwin_size_t = ::std::os::raw::c_ulong;
-pub type __darwin_va_list = __builtin_va_list;
-pub type __darwin_wchar_t = ::std::os::raw::c_int;
-pub type __darwin_rune_t = __darwin_wchar_t;
-pub type __darwin_wint_t = ::std::os::raw::c_int;
-pub type __darwin_clock_t = ::std::os::raw::c_ulong;
-pub type __darwin_socklen_t = __uint32_t;
-pub type __darwin_ssize_t = ::std::os::raw::c_long;
-pub type __darwin_time_t = ::std::os::raw::c_long;
-pub type __darwin_blkcnt_t = __int64_t;
-pub type __darwin_blksize_t = __int32_t;
-pub type __darwin_dev_t = __int32_t;
-pub type __darwin_fsblkcnt_t = ::std::os::raw::c_uint;
-pub type __darwin_fsfilcnt_t = ::std::os::raw::c_uint;
-pub type __darwin_gid_t = __uint32_t;
-pub type __darwin_id_t = __uint32_t;
-pub type __darwin_ino64_t = __uint64_t;
-pub type __darwin_ino_t = __darwin_ino64_t;
-pub type __darwin_mach_port_name_t = __darwin_natural_t;
-pub type __darwin_mach_port_t = __darwin_mach_port_name_t;
-pub type __darwin_mode_t = __uint16_t;
-pub type __darwin_off_t = __int64_t;
-pub type __darwin_pid_t = __int32_t;
-pub type __darwin_sigset_t = __uint32_t;
-pub type __darwin_suseconds_t = __int32_t;
-pub type __darwin_uid_t = __uint32_t;
-pub type __darwin_useconds_t = __uint32_t;
-pub type __darwin_uuid_t = [::std::os::raw::c_uchar; 16usize];
-pub type __darwin_uuid_string_t = [::std::os::raw::c_char; 37usize];
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct __darwin_pthread_handler_rec {
-    pub __routine: ::std::option::Option<unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void)>,
-    pub __arg: *mut ::std::os::raw::c_void,
-    pub __next: *mut __darwin_pthread_handler_rec,
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of __darwin_pthread_handler_rec"]
-        [::std::mem::size_of::<__darwin_pthread_handler_rec>() - 24usize];
-    ["Alignment of __darwin_pthread_handler_rec"]
-        [::std::mem::align_of::<__darwin_pthread_handler_rec>() - 8usize];
-    ["Offset of field: __darwin_pthread_handler_rec::__routine"]
-        [::std::mem::offset_of!(__darwin_pthread_handler_rec, __routine) - 0usize];
-    ["Offset of field: __darwin_pthread_handler_rec::__arg"]
-        [::std::mem::offset_of!(__darwin_pthread_handler_rec, __arg) - 8usize];
-    ["Offset of field: __darwin_pthread_handler_rec::__next"]
-        [::std::mem::offset_of!(__darwin_pthread_handler_rec, __next) - 16usize];
-};
 #[repr(C)]
+#[repr(align(16))]
 #[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_attr_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 56usize],
+pub struct max_align_t {
+    pub __clang_max_align_nonce1: ::std::os::raw::c_longlong,
+    pub __bindgen_padding_0: u64,
+    pub __clang_max_align_nonce2: u128,
 }
 #[allow(clippy::unnecessary_operation, clippy::identity_op)]
 const _: () = {
-    ["Size of _opaque_pthread_attr_t"][::std::mem::size_of::<_opaque_pthread_attr_t>() - 64usize];
-    ["Alignment of _opaque_pthread_attr_t"]
-        [::std::mem::align_of::<_opaque_pthread_attr_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_attr_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_attr_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_attr_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_attr_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_cond_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 40usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_cond_t"][::std::mem::size_of::<_opaque_pthread_cond_t>() - 48usize];
-    ["Alignment of _opaque_pthread_cond_t"]
-        [::std::mem::align_of::<_opaque_pthread_cond_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_cond_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_cond_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_cond_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_cond_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_condattr_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 8usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_condattr_t"]
-        [::std::mem::size_of::<_opaque_pthread_condattr_t>() - 16usize];
-    ["Alignment of _opaque_pthread_condattr_t"]
-        [::std::mem::align_of::<_opaque_pthread_condattr_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_condattr_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_condattr_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_condattr_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_condattr_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_mutex_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 56usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_mutex_t"][::std::mem::size_of::<_opaque_pthread_mutex_t>() - 64usize];
-    ["Alignment of _opaque_pthread_mutex_t"]
-        [::std::mem::align_of::<_opaque_pthread_mutex_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_mutex_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_mutex_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_mutex_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_mutex_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_mutexattr_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 8usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_mutexattr_t"]
-        [::std::mem::size_of::<_opaque_pthread_mutexattr_t>() - 16usize];
-    ["Alignment of _opaque_pthread_mutexattr_t"]
-        [::std::mem::align_of::<_opaque_pthread_mutexattr_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_mutexattr_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_mutexattr_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_once_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 8usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_once_t"][::std::mem::size_of::<_opaque_pthread_once_t>() - 16usize];
-    ["Alignment of _opaque_pthread_once_t"]
-        [::std::mem::align_of::<_opaque_pthread_once_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_once_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_once_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_once_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_once_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_rwlock_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 192usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_rwlock_t"]
-        [::std::mem::size_of::<_opaque_pthread_rwlock_t>() - 200usize];
-    ["Alignment of _opaque_pthread_rwlock_t"]
-        [::std::mem::align_of::<_opaque_pthread_rwlock_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_rwlock_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_rwlock_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __opaque) - 8usize];
-};
-#[repr(C)]
-#[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_rwlockattr_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __opaque: [::std::os::raw::c_char; 16usize],
-}
-#[allow(clippy::unnecessary_operation, clippy::identity_op)]
-const _: () = {
-    ["Size of _opaque_pthread_rwlockattr_t"]
-        [::std::mem::size_of::<_opaque_pthread_rwlockattr_t>() - 24usize];
-    ["Alignment of _opaque_pthread_rwlockattr_t"]
-        [::std::mem::align_of::<_opaque_pthread_rwlockattr_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_rwlockattr_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_rwlockattr_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __opaque) - 8usize];
+    ["Size of max_align_t"][::std::mem::size_of::<max_align_t>() - 32usize];
+    ["Alignment of max_align_t"][::std::mem::align_of::<max_align_t>() - 16usize];
+    ["Offset of field: max_align_t::__clang_max_align_nonce1"]
+        [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce1) - 0usize];
+    ["Offset of field: max_align_t::__clang_max_align_nonce2"]
+        [::std::mem::offset_of!(max_align_t, __clang_max_align_nonce2) - 16usize];
 };
+pub type __u_char = ::std::os::raw::c_uchar;
+pub type __u_short = ::std::os::raw::c_ushort;
+pub type __u_int = ::std::os::raw::c_uint;
+pub type __u_long = ::std::os::raw::c_ulong;
+pub type __int8_t = ::std::os::raw::c_schar;
+pub type __uint8_t = ::std::os::raw::c_uchar;
+pub type __int16_t = ::std::os::raw::c_short;
+pub type __uint16_t = ::std::os::raw::c_ushort;
+pub type __int32_t = ::std::os::raw::c_int;
+pub type __uint32_t = ::std::os::raw::c_uint;
+pub type __int64_t = ::std::os::raw::c_long;
+pub type __uint64_t = ::std::os::raw::c_ulong;
+pub type __int_least8_t = __int8_t;
+pub type __uint_least8_t = __uint8_t;
+pub type __int_least16_t = __int16_t;
+pub type __uint_least16_t = __uint16_t;
+pub type __int_least32_t = __int32_t;
+pub type __uint_least32_t = __uint32_t;
+pub type __int_least64_t = __int64_t;
+pub type __uint_least64_t = __uint64_t;
+pub type __quad_t = ::std::os::raw::c_long;
+pub type __u_quad_t = ::std::os::raw::c_ulong;
+pub type __intmax_t = ::std::os::raw::c_long;
+pub type __uintmax_t = ::std::os::raw::c_ulong;
+pub type __dev_t = ::std::os::raw::c_ulong;
+pub type __uid_t = ::std::os::raw::c_uint;
+pub type __gid_t = ::std::os::raw::c_uint;
+pub type __ino_t = ::std::os::raw::c_ulong;
+pub type __ino64_t = ::std::os::raw::c_ulong;
+pub type __mode_t = ::std::os::raw::c_uint;
+pub type __nlink_t = ::std::os::raw::c_ulong;
+pub type __off_t = ::std::os::raw::c_long;
+pub type __off64_t = ::std::os::raw::c_long;
+pub type __pid_t = ::std::os::raw::c_int;
 #[repr(C)]
 #[derive(Debug, Copy, Clone)]
-pub struct _opaque_pthread_t {
-    pub __sig: ::std::os::raw::c_long,
-    pub __cleanup_stack: *mut __darwin_pthread_handler_rec,
-    pub __opaque: [::std::os::raw::c_char; 8176usize],
+pub struct __fsid_t {
+    pub __val: [::std::os::raw::c_int; 2usize],
 }
 #[allow(clippy::unnecessary_operation, clippy::identity_op)]
 const _: () = {
-    ["Size of _opaque_pthread_t"][::std::mem::size_of::<_opaque_pthread_t>() - 8192usize];
-    ["Alignment of _opaque_pthread_t"][::std::mem::align_of::<_opaque_pthread_t>() - 8usize];
-    ["Offset of field: _opaque_pthread_t::__sig"]
-        [::std::mem::offset_of!(_opaque_pthread_t, __sig) - 0usize];
-    ["Offset of field: _opaque_pthread_t::__cleanup_stack"]
-        [::std::mem::offset_of!(_opaque_pthread_t, __cleanup_stack) - 8usize];
-    ["Offset of field: _opaque_pthread_t::__opaque"]
-        [::std::mem::offset_of!(_opaque_pthread_t, __opaque) - 16usize];
+    ["Size of __fsid_t"][::std::mem::size_of::<__fsid_t>() - 8usize];
+    ["Alignment of __fsid_t"][::std::mem::align_of::<__fsid_t>() - 4usize];
+    ["Offset of field: __fsid_t::__val"][::std::mem::offset_of!(__fsid_t, __val) - 0usize];
 };
-pub type __darwin_pthread_attr_t = _opaque_pthread_attr_t;
-pub type __darwin_pthread_cond_t = _opaque_pthread_cond_t;
-pub type __darwin_pthread_condattr_t = _opaque_pthread_condattr_t;
-pub type __darwin_pthread_key_t = ::std::os::raw::c_ulong;
-pub type __darwin_pthread_mutex_t = _opaque_pthread_mutex_t;
-pub type __darwin_pthread_mutexattr_t = _opaque_pthread_mutexattr_t;
-pub type __darwin_pthread_once_t = _opaque_pthread_once_t;
-pub type __darwin_pthread_rwlock_t = _opaque_pthread_rwlock_t;
-pub type __darwin_pthread_rwlockattr_t = _opaque_pthread_rwlockattr_t;
-pub type __darwin_pthread_t = *mut _opaque_pthread_t;
-pub type intmax_t = ::std::os::raw::c_long;
-pub type uintmax_t = ::std::os::raw::c_ulong;
+pub type __clock_t = ::std::os::raw::c_long;
+pub type __rlim_t = ::std::os::raw::c_ulong;
+pub type __rlim64_t = ::std::os::raw::c_ulong;
+pub type __id_t = ::std::os::raw::c_uint;
+pub type __time_t = ::std::os::raw::c_long;
+pub type __useconds_t = ::std::os::raw::c_uint;
+pub type __suseconds_t = ::std::os::raw::c_long;
+pub type __suseconds64_t = ::std::os::raw::c_long;
+pub type __daddr_t = ::std::os::raw::c_int;
+pub type __key_t = ::std::os::raw::c_int;
+pub type __clockid_t = ::std::os::raw::c_int;
+pub type __timer_t = *mut ::std::os::raw::c_void;
+pub type __blksize_t = ::std::os::raw::c_long;
+pub type __blkcnt_t = ::std::os::raw::c_long;
+pub type __blkcnt64_t = ::std::os::raw::c_long;
+pub type __fsblkcnt_t = ::std::os::raw::c_ulong;
+pub type __fsblkcnt64_t = ::std::os::raw::c_ulong;
+pub type __fsfilcnt_t = ::std::os::raw::c_ulong;
+pub type __fsfilcnt64_t = ::std::os::raw::c_ulong;
+pub type __fsword_t = ::std::os::raw::c_long;
+pub type __ssize_t = ::std::os::raw::c_long;
+pub type __syscall_slong_t = ::std::os::raw::c_long;
+pub type __syscall_ulong_t = ::std::os::raw::c_ulong;
+pub type __loff_t = __off64_t;
+pub type __caddr_t = *mut ::std::os::raw::c_char;
+pub type __intptr_t = ::std::os::raw::c_long;
+pub type __socklen_t = ::std::os::raw::c_uint;
+pub type __sig_atomic_t = ::std::os::raw::c_int;
+pub type int_least8_t = __int_least8_t;
+pub type int_least16_t = __int_least16_t;
+pub type int_least32_t = __int_least32_t;
+pub type int_least64_t = __int_least64_t;
+pub type uint_least8_t = __uint_least8_t;
+pub type uint_least16_t = __uint_least16_t;
+pub type uint_least32_t = __uint_least32_t;
+pub type uint_least64_t = __uint_least64_t;
+pub type int_fast8_t = ::std::os::raw::c_schar;
+pub type int_fast16_t = ::std::os::raw::c_long;
+pub type int_fast32_t = ::std::os::raw::c_long;
+pub type int_fast64_t = ::std::os::raw::c_long;
+pub type uint_fast8_t = ::std::os::raw::c_uchar;
+pub type uint_fast16_t = ::std::os::raw::c_ulong;
+pub type uint_fast32_t = ::std::os::raw::c_ulong;
+pub type uint_fast64_t = ::std::os::raw::c_ulong;
+pub type intmax_t = __intmax_t;
+pub type uintmax_t = __uintmax_t;
 #[doc = " Wrapper around Rust's [`&str`], without allocating memory, unlike [`std::ffi::CString`].\n The caller must use it as a Rust string. This is not a C-string."]
 #[repr(C)]
 #[derive(Debug, Copy, Clone)]
@@ -449,4 +324,3 @@ const _: () = {
     ["Offset of field: PdRoute::shard"][::std::mem::offset_of!(PdRoute, shard) - 0usize];
     ["Offset of field: PdRoute::read_write"][::std::mem::offset_of!(PdRoute, read_write) - 8usize];
 };
-pub type __builtin_va_list = *mut ::std::os::raw::c_char;
diff --git a/pgdog-stats/src/schema.rs b/pgdog-stats/src/schema.rs
index a92af097c..1c29f2f96 100644
--- a/pgdog-stats/src/schema.rs
+++ b/pgdog-stats/src/schema.rs
@@ -65,6 +65,9 @@ pub struct Relation {
     pub oid: i32,
     /// Columns indexed by name, ordered by ordinal position.
    pub columns: IndexMap,
+    /// Whether this relation is sharded.
+    #[serde(default)]
+    pub is_sharded: bool,
 }
 
 impl Hash for Relation {
@@ -81,6 +84,7 @@ impl Hash for Relation {
             key.hash(state);
             value.hash(state);
         }
+        self.is_sharded.hash(state);
     }
 }
diff --git a/pgdog/src/backend/error.rs b/pgdog/src/backend/error.rs
index f24651985..5f9f7e9f1 100644
--- a/pgdog/src/backend/error.rs
+++ b/pgdog/src/backend/error.rs
@@ -134,6 +134,12 @@ pub enum Error {
 
     #[error("unsupported aggregation {function}: {reason}")]
     UnsupportedAggregation { function: String, reason: String },
+
+    #[error("no foreign key path found from table \"{table}\" to any sharded table")]
+    NoForeignKeyPath { table: String },
+
+    #[error("sharding error: {0}")]
+    Sharding(String),
 }
 
 impl From for Error {
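A note on the #[serde(default)] in the schema.rs hunk above: Relation values are serialized, and payloads written before this change won't carry is_sharded. The attribute makes them decode to false instead of failing on a missing field. A freestanding illustration of that behavior (struct name hypothetical; serde and serde_json assumed as dependencies):

    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    struct Rel {
        oid: i32,
        #[serde(default)] // absent field decodes to bool::default() == false
        is_sharded: bool,
    }

    fn main() {
        let old: Rel = serde_json::from_str(r#"{"oid": 42}"#).unwrap();
        assert!(!old.is_sharded);
    }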
diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs
index 33c0798a9..ba9c8386e 100644
--- a/pgdog/src/backend/pool/cluster.rs
+++ b/pgdog/src/backend/pool/cluster.rs
@@ -542,14 +542,17 @@
             .swap(true, Ordering::SeqCst);
 
         if self.load_schema() && !already_started {
+            let sharding_schema = self.sharding_schema();
+
             for shard in self.shards() {
                 let identifier = self.identifier();
                 let readiness = self.readiness.clone();
                 let shard = shard.clone();
                 let shards = self.shards.len();
+                let sharding = sharding_schema.clone();
 
                 spawn(async move {
-                    if let Err(err) = shard.update_schema().await {
+                    if let Err(err) = shard.update_schema(&sharding).await {
                         error!("error loading schema for shard {}: {}", shard.number(), err);
                     }
@@ -632,6 +635,7 @@ mod test {
         backend::{
             pool::{Address, Config, PoolConfig, ShardConfig},
             replication::ShardedSchemas,
+            schema::test_helpers::prelude::*,
             Shard, ShardedTables,
         },
         config::{
@@ -644,6 +648,8 @@ impl Cluster {
         pub fn new_test(config: &ConfigAndUsers) -> Self {
+            use crate::backend::replication::ShardedSchemas as SS;
+
             let identifier = Arc::new(DatabaseUser {
                 user: "pgdog".into(),
                 database: "pgdog".into(),
             });
             let primary = &Some(PoolConfig {
                 address: Address::new_test(),
                 config: Config::default(),
             });
             let replicas = &[PoolConfig {
                 address: Address::new_test(),
                 config: Config::default(),
             }];
 
+            let sharded_tables = ShardedTables::new(
+                vec![ShardedTable {
+                    database: "pgdog".into(),
+                    name: Some("sharded".into()),
+                    column: "id".into(),
+                    primary: true,
+                    centroids: vec![],
+                    data_type: DataType::Bigint,
+                    centroids_path: None,
+                    centroid_probes: 1,
+                    hasher: Hasher::Postgres,
+                    ..Default::default()
+                }],
+                vec![
+                    OmnishardedTable {
+                        name: "sharded_omni".into(),
+                        sticky_routing: false,
+                    },
+                    OmnishardedTable {
+                        name: "sharded_omni_sticky".into(),
+                        sticky_routing: true,
+                    },
+                ],
+                config.config.general.omnisharded_sticky,
+                config.config.general.system_catalogs,
+            );
+
+            let sharded_schemas = ShardedSchemas::new(vec![
+                ShardedSchema {
+                    database: "pgdog".into(),
+                    name: Some("shard_0".into()),
+                    shard: 0,
+                    ..Default::default()
+                },
+                ShardedSchema {
+                    database: "pgdog".into(),
+                    name: Some("shard_1".into()),
+                    shard: 1,
+                    ..Default::default()
+                },
+            ]);
+
+            // Create test schema with sharded tables
+            let mut db_schema = schema()
+                .relation(table("sharded").oid(1).column(pk("id")))
+                .relation(table("sharded_omni").oid(2).column(pk("id")))
+                .relation(table("sharded_omni_sticky").oid(3).column(pk("id")))
+                .build();
+
+            // Compute sharded joins to set is_sharded flags
+            let sharding_schema = super::ShardingSchema {
+                shards: 2,
+                tables: sharded_tables.clone(),
+                schemas: SS::new(vec![]),
+                ..Default::default()
+            };
+            db_schema.computed_sharded_joins(&sharding_schema);
+
             let shards = (0..2)
                 .into_iter()
                 .map(|number| {
-                    Shard::new(ShardConfig {
+                    let shard = Shard::new(ShardConfig {
                         number,
                         primary,
                         replicas,
@@ -668,51 +732,16 @@ mod test {
                         rw_split: ReadWriteSplit::IncludePrimary,
                         identifier: identifier.clone(),
                         lsn_check_interval: Duration::MAX,
-                    })
+                    });
+                    // Set the test schema on each shard
+                    shard.set_test_schema(db_schema.clone());
+                    shard
                 })
                 .collect::<Vec<_>>();
 
             Cluster {
-                sharded_tables: ShardedTables::new(
-                    vec![ShardedTable {
-                        database: "pgdog".into(),
-                        name: Some("sharded".into()),
-                        column: "id".into(),
-                        primary: true,
-                        centroids: vec![],
-                        data_type: DataType::Bigint,
-                        centroids_path: None,
-                        centroid_probes: 1,
-                        hasher: Hasher::Postgres,
-                        ..Default::default()
-                    }],
-                    vec![
-                        OmnishardedTable {
-                            name: "sharded_omni".into(),
-                            sticky_routing: false,
-                        },
-                        OmnishardedTable {
-                            name: "sharded_omni_sticky".into(),
-                            sticky_routing: true,
-                        },
-                    ],
-                    config.config.general.omnisharded_sticky,
-                    config.config.general.system_catalogs,
-                ),
-                sharded_schemas: ShardedSchemas::new(vec![
-                    ShardedSchema {
-                        database: "pgdog".into(),
-                        name: Some("shard_0".into()),
-                        shard: 0,
-                        ..Default::default()
-                    },
-                    ShardedSchema {
-                        database: "pgdog".into(),
-                        name: Some("shard_1".into()),
-                        shard: 1,
-                        ..Default::default()
-                    },
-                ]),
+                sharded_tables,
+                sharded_schemas,
                 shards,
                 identifier,
                 prepared_statements: config.config.general.prepared_statements,
@@ -736,6 +765,63 @@ mod test {
         pub fn set_read_write_strategy(&mut self, rw_strategy: ReadWriteStrategy) {
             self.rw_strategy = rw_strategy;
         }
+
+        /// Create a test cluster with custom schema and sharding tables.
+        /// Unlike new_test(), this doesn't preset the schema on shards,
+        /// allowing the caller to set their own schema.
+        pub fn new_test_with_sharding(
+            sharded_tables: ShardedTables,
+            db_schema: crate::backend::schema::Schema,
+        ) -> Self {
+            let config = ConfigAndUsers::default();
+            let identifier = Arc::new(DatabaseUser {
+                user: "pgdog".into(),
+                database: "pgdog".into(),
+            });
+            let primary = &Some(PoolConfig {
+                address: Address::new_test(),
+                config: Config::default(),
+            });
+            let replicas = &[PoolConfig {
+                address: Address::new_test(),
+                config: Config::default(),
+            }];
+
+            let sharded_schemas = ShardedSchemas::new(vec![]);
+
+            let shards = (0..2)
+                .into_iter()
+                .map(|number| {
+                    let shard = Shard::new(ShardConfig {
+                        number,
+                        primary,
+                        replicas,
+                        lb_strategy: LoadBalancingStrategy::Random,
+                        rw_split: ReadWriteSplit::IncludePrimary,
+                        identifier: identifier.clone(),
+                        lsn_check_interval: Duration::MAX,
+                    });
+                    // Set the custom schema on each shard
+                    shard.set_test_schema(db_schema.clone());
+                    shard
+                })
+                .collect::<Vec<_>>();
+
+            Cluster {
+                sharded_tables,
+                sharded_schemas,
+                shards,
+                identifier,
+                prepared_statements: config.config.general.prepared_statements,
+                dry_run: config.config.general.dry_run,
+                expanded_explain: config.config.general.expanded_explain,
+                query_parser: config.config.general.query_parser,
+                rewrite: config.config.rewrite.clone(),
+                two_phase_commit: config.config.general.two_phase_commit,
+                two_phase_commit_auto: config.config.general.two_phase_commit_auto.unwrap_or(false),
+                ..Default::default()
+            }
+        }
     }
 
     #[test]
diff --git a/pgdog/src/backend/pool/connection/mod.rs b/pgdog/src/backend/pool/connection/mod.rs
index 7ea427941..ec491558d 100644
--- a/pgdog/src/backend/pool/connection/mod.rs
+++ b/pgdog/src/backend/pool/connection/mod.rs
@@ -296,6 +296,7 @@ impl Connection {
         if client_request.is_copy() && !streaming {
             let rows = router
                 .copy_data(client_request)
+                .await
                 .map_err(|e| Error::Router(e.to_string()))?;
             if !rows.is_empty() {
                 self.send_copy(rows).await?;
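The single .await added in connection/mod.rs is the visible edge of a contract change: sharding a COPY chunk could previously be computed purely from the row bytes, but with FK lookup it may require a round trip to a backend, so router.copy_data(...) now returns a future. Only the call shape shown in the hunk is restated here; the surrounding types are not spelled out in this diff:

    let rows = router
        .copy_data(client_request)
        .await // may perform the FK lookup introduced below
        .map_err(|e| Error::Router(e.to_string()))?;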
diff --git a/pgdog/src/backend/pool/shard/mod.rs b/pgdog/src/backend/pool/shard/mod.rs
index fca067eb7..5bb8fc6ed 100644
--- a/pgdog/src/backend/pool/shard/mod.rs
+++ b/pgdog/src/backend/pool/shard/mod.rs
@@ -126,12 +126,20 @@ impl Shard {
     }
 
     /// Load schema from the shard's primary.
-    pub async fn update_schema(&self) -> Result<(), crate::backend::Error> {
+    pub async fn update_schema(
+        &self,
+        sharding: &crate::backend::ShardingSchema,
+    ) -> Result<(), crate::backend::Error> {
         let mut server = self.primary_or_replica(&Request::default()).await?;
-        let schema = Schema::load(&mut server).await?;
+        let mut schema = Schema::load(&mut server).await?;
+
+        // Compute joins for tables that don't have the sharding key directly
+        schema.computed_sharded_joins(sharding);
+
         info!(
-            "loaded schema for {} tables on shard {} [{}]",
+            "loaded schema for {} tables ({} joins) on shard {} [{}]",
             schema.tables().len(),
+            schema.joins_count(),
             self.number(),
             server.addr()
         );
@@ -287,6 +295,14 @@ impl ShardInner {
     }
 }
 
+#[cfg(test)]
+impl Shard {
+    /// Set the schema for testing purposes.
+    pub fn set_test_schema(&self, schema: Schema) {
+        let _ = self.inner.schema.set(schema);
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::collections::BTreeSet;
diff --git a/pgdog/src/backend/replication/logical/subscriber/copy.rs b/pgdog/src/backend/replication/logical/subscriber/copy.rs
index 28741b152..fa8c75b7d 100644
--- a/pgdog/src/backend/replication/logical/subscriber/copy.rs
+++ b/pgdog/src/backend/replication/logical/subscriber/copy.rs
@@ -170,7 +170,7 @@ impl CopySubscriber {
     }
 
     async fn flush(&mut self) -> Result<(), Error> {
-        let result = self.copy.shard(&self.buffer)?;
+        let result = self.copy.shard(&self.buffer).await?;
         self.buffer.clear();
 
         for row in &result {
diff --git a/pgdog/src/backend/schema/fk_lookup.rs b/pgdog/src/backend/schema/fk_lookup.rs
new file mode 100644
index 000000000..b61c44e8a
--- /dev/null
+++ b/pgdog/src/backend/schema/fk_lookup.rs
@@ -0,0 +1,432 @@
+//! Foreign key lookup for sharding.
+//!
+//! Used during COPY when a table is sharded but doesn't have the sharding key directly.
+//! Queries the cluster via FK joins to find the sharding key value.
+
+use std::collections::HashMap;
+use tracing::debug;
+
+use crate::{
+    backend::{
+        pool::{Guard, Request},
+        Cluster, Error,
+    },
+    config::ShardedTable,
+    frontend::router::{
+        parser::Shard,
+        sharding::{ContextBuilder, Data as ShardingData},
+    },
+    net::{
+        messages::{Bind, DataRow, Format, FromBytes, Message, Parameter, Protocol, ToBytes},
+        Execute, Parse, ProtocolMessage, Sync,
+    },
+};
+
+use super::Join;
+
+/// FK lookup state for a single shard connection.
+#[derive(Debug)]
+struct ShardConnection {
+    server: Guard,
+    prepared: bool,
+}
+
+/// Foreign key lookup for resolving sharding keys via FK relationships.
+///
+/// Keeps replica connections to all shards open and uses prepared statements
+/// to minimize query overhead (Parse once, then Bind/Execute/Sync for each lookup).
+#[derive(Debug)]
+pub struct FkLookup {
+    /// The join query and sharding info.
+    join: Join,
+    /// Unique name for the prepared statement.
+    prepared_name: String,
+    /// Number of shards.
+    num_shards: usize,
+    /// Connections to each shard (lazily initialized).
+    connections: HashMap<usize, ShardConnection>,
+    /// The cluster for connections.
+    cluster: Cluster,
+}
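Putting FkLookup together with construct_join (defined in join.rs later in this diff), usage condenses to the following. This is abridged from test_fk_lookup_query below and assumes a launched two-shard test cluster; error handling is elided:

    let relation = db_schema.inner.get("pgdog", "fk_orders").unwrap();
    let join = db_schema.construct_join(relation, &sharding_schema).unwrap();
    let mut lookup = FkLookup::new(join, cluster.clone());

    // Pass the FK value from the COPY row (fk_orders.user_id); the lookup
    // resolves fk_users.user_id behind it and applies the sharding function.
    let shard = lookup.lookup(ShardingData::Text("42")).await.unwrap();
    assert!(matches!(shard, Shard::Direct(_) | Shard::All));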
+
+impl FkLookup {
+    /// Create a new FK lookup from a Join.
+    pub fn new(join: Join, cluster: Cluster) -> Self {
+        let prepared_name = format!("__pgdog_fk_{}", uuid::Uuid::new_v4());
+        let num_shards = cluster.shards().len();
+        Self {
+            join,
+            prepared_name,
+            num_shards,
+            connections: HashMap::new(),
+            cluster,
+        }
+    }
+
+    /// Get or create a connection to a shard replica.
+    async fn ensure_connection(&mut self, shard: usize) -> Result<(), Error> {
+        if !self.connections.contains_key(&shard) {
+            let server = self.cluster.replica(shard, &Request::default()).await?;
+            self.connections.insert(
+                shard,
+                ShardConnection {
+                    server,
+                    prepared: false,
+                },
+            );
+        }
+        Ok(())
+    }
+
+    /// Prepare the statement on a shard if not already prepared.
+    async fn ensure_prepared(&mut self, shard: usize) -> Result<(), Error> {
+        self.ensure_connection(shard).await?;
+
+        // Check if already prepared
+        if self
+            .connections
+            .get(&shard)
+            .map(|c| c.prepared)
+            .unwrap_or(false)
+        {
+            return Ok(());
+        }
+
+        // Clone values needed for the Parse message before getting mutable borrow
+        let prepared_name = self.prepared_name.clone();
+        let query = self.join.query.clone();
+
+        let parse = Parse::named(&prepared_name, &query);
+        let messages = vec![ProtocolMessage::from(parse), Sync.into()];
+
+        let conn = self.connections.get_mut(&shard).unwrap();
+        conn.server.send(&messages.into()).await?;
+
+        // Read responses until ReadyForQuery
+        loop {
+            let msg: Message = conn.server.read().await?;
+            match msg.code() {
+                'Z' => break,
+                'E' => {
+                    let err = crate::net::messages::ErrorResponse::from_bytes(msg.to_bytes()?)?;
+                    return Err(Error::ConnectionError(Box::new(err)));
+                }
+                _ => continue,
+            }
+        }
+
+        conn.prepared = true;
+        Ok(())
+    }
+
+    /// Look up the sharding key for a given primary key value.
+    ///
+    /// Queries all shards in parallel to find the row,
+    /// then applies the sharding function to determine the target shard.
+    /// Accepts both text and binary formats via ShardingData.
+    pub async fn lookup(&mut self, pk_value: ShardingData<'_>) -> Result<Shard, Error> {
+        // Ensure all connections are prepared first
+        for shard in 0..self.num_shards {
+            self.ensure_prepared(shard).await?;
+        }
+
+        let (param, format) = pk_value.parameter_with_format();
+
+        // Take connections out to allow parallel querying
+        let mut connections = std::mem::take(&mut self.connections);
+        let prepared_name = self.prepared_name.clone();
+
+        // Create futures for all shards
+        let futures: Vec<_> = (0..self.num_shards)
+            .filter_map(|shard| {
+                connections.remove(&shard).map(|conn| {
+                    Self::query_connection(
+                        conn,
+                        prepared_name.clone(),
+                        param.clone(),
+                        format,
+                        shard,
+                    )
+                })
+            })
+            .collect();
+
+        // Run all queries in parallel
+        let results = futures::future::join_all(futures).await;
+
+        // Put connections back and find result
+        let mut sharding_key = None;
+        for (shard, conn, result) in results {
+            self.connections.insert(shard, conn);
+            if let Ok(Some(key)) = result {
+                sharding_key = Some(key);
+            }
+        }
+
+        // Apply sharding
+        let shard = if let Some(key) = sharding_key {
+            self.apply_sharding(&key)?
+        } else {
+            Shard::All
+        };
+
+        debug!(
+            "sharding key via foreign key lookup resolved to shard={}",
+            shard
+        );
+
+        Ok(shard)
+    }
+
+    /// Query a single connection for the sharding key.
+    async fn query_connection(
+        mut conn: ShardConnection,
+        prepared_name: String,
+        param: Parameter,
+        format: Format,
+        shard: usize,
+    ) -> (usize, ShardConnection, Result<Option<String>, Error>) {
+        let result = Self::do_query(&mut conn, &prepared_name, param, format).await;
+        (shard, conn, result)
+    }
+
+    /// Execute the actual query on a connection.
+    async fn do_query(
+        conn: &mut ShardConnection,
+        prepared_name: &str,
+        param: Parameter,
+        format: Format,
+    ) -> Result<Option<String>, Error> {
+        // Request text format (0) for results so we can parse as string
+        let bind = Bind::new_params_codes_results(prepared_name, &[param], &[format], &[0]);
+        let execute = Execute::new();
+
+        let messages = vec![ProtocolMessage::from(bind), execute.into(), Sync.into()];
+        conn.server.send(&messages.into()).await?;
+
+        let mut result: Option<String> = None;
+
+        // Read responses until ReadyForQuery
+        loop {
+            let msg: Message = conn.server.read().await?;
+            match msg.code() {
+                'D' => {
+                    let data_row = DataRow::from_bytes(msg.to_bytes()?)?;
+                    if let Some(value) = data_row.get_text(0) {
+                        result = Some(value);
+                    }
+                }
+                'Z' => break,
+                'E' => {
+                    let err = crate::net::messages::ErrorResponse::from_bytes(msg.to_bytes()?)?;
+                    return Err(Error::ConnectionError(Box::new(err)));
+                }
+                _ => continue,
+            }
+        }
+
+        Ok(result)
+    }
+
+    /// Apply the sharding function to the key value.
+    fn apply_sharding(&self, key: &str) -> Result<Shard, Error> {
+        let table = &self.join.sharded_table;
+        let ctx = ContextBuilder::new(table)
+            .data(key)
+            .shards(self.num_shards)
+            .build()
+            .map_err(|e| Error::Sharding(e.to_string()))?;
+
+        ctx.apply().map_err(|e| Error::Sharding(e.to_string()))
+    }
+
+    /// Get a reference to the join configuration.
+    pub fn join(&self) -> &Join {
+        &self.join
+    }
+
+    /// Get the sharded table configuration.
+    pub fn sharded_table(&self) -> &ShardedTable {
+        &self.join.sharded_table
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::backend::schema::test_helpers::prelude::*;
+    use crate::config::config;
+
+    #[test]
+    fn test_fk_lookup_new() {
+        let sharding_schema = sharding().sharded_table("users", "user_id").build();
+
+        let db_schema = schema()
+            .relation(
+                table("users")
+                    .oid(1001)
+                    .column(pk("id"))
+                    .column(col("user_id")),
+            )
+            .relation(
+                table("orders")
+                    .oid(1002)
+                    .column(pk("id"))
+                    .column(fk("user_id", "users", "id")),
+            )
+            .build();
+
+        let relation = db_schema.inner.get("public", "orders").unwrap();
+        let join = db_schema
+            .construct_join(relation, &sharding_schema)
+            .unwrap();
+
+        let lookup = FkLookup::new(join.clone(), Cluster::default());
+
+        assert!(lookup.prepared_name.starts_with("__pgdog_fk_"));
+        assert_eq!(lookup.join().sharding_column.name, "user_id");
+        assert_eq!(lookup.sharded_table().column, "user_id");
+    }
+
+    #[tokio::test]
+    async fn test_fk_lookup_query() {
+        crate::logger();
+
+        let cluster = Cluster::new_test(&config());
+        cluster.launch();
+
+        // Use a single connection for setup (both shards point to same DB in tests)
+        let mut server = cluster.primary(0, &Request::default()).await.unwrap();
+
+        // Drop and recreate tables with FK relationship
+        server
+            .execute("DROP TABLE IF EXISTS pgdog.fk_orders, pgdog.fk_users")
+            .await
+            .unwrap();
+
+        server
+            .execute(
+                "CREATE TABLE pgdog.fk_users (
+                    id BIGINT PRIMARY KEY,
+                    user_id BIGINT NOT NULL
+                )",
+            )
+            .await
+            .unwrap();
+
+        server
+            .execute(
+                "CREATE TABLE pgdog.fk_orders (
+                    id BIGINT PRIMARY KEY,
+                    user_id BIGINT REFERENCES pgdog.fk_users(id)
+                )",
+            )
+            .await
+            .unwrap();
+
+        // Insert 1000 users and orders
+        for i in 1i64..=1000 {
+            server
+                .execute(&format!(
+                    "INSERT INTO pgdog.fk_users (id, user_id) VALUES ({}, {})",
+                    i,
+                    i * 100 + i % 17 // Add some variation to user_id
+                ))
+                .await
+                .unwrap();
+
+            server
+                .execute(&format!(
+                    "INSERT INTO pgdog.fk_orders (id, user_id) VALUES ({}, {})",
+                    i * 10, // order id
+                    i       // references user id
+                ))
+                .await
+                .unwrap();
+        }
+
+        // Build schema with FK relationships
+        let sharding_schema = sharding()
+            .sharded_table("fk_users", "user_id")
+            .database("pgdog")
+            .build();
+
+        let db_schema = schema()
+            .relation(
+                table("fk_users")
+                    .schema("pgdog")
+                    .oid(1001)
+                    .column(pk("id"))
+                    .column(col("user_id")),
+            )
+            .relation(
+                table("fk_orders")
+                    .schema("pgdog")
+                    .oid(1002)
+                    .column(pk("id"))
+                    .column(ColumnBuilder::new("user_id").foreign_key("pgdog", "fk_users", "id")),
+            )
+            .build();
+
+        let relation = db_schema.inner.get("pgdog", "fk_orders").unwrap();
+        let join = db_schema
+            .construct_join(relation, &sharding_schema)
+            .unwrap();
+
+        // Query directly looks up sharding key from fk_users
+        assert!(join.query.contains("user_id"));
+        assert!(join.query.contains("fk_users"));
+
+        let mut lookup = FkLookup::new(join, cluster.clone());
+
+        // Test all 1000 FK values - pass the FK value (user_id which references fk_users.id)
+        for i in 1i64..=1000 {
+            // The FK value is `i` (references fk_users.id)
+            let fk_value = i;
+            let fk_value_str = fk_value.to_string();
+
+            let text_result = lookup
+                .lookup(ShardingData::Text(&fk_value_str))
+                .await
+                .unwrap();
+            let binary_result = lookup
+                .lookup(ShardingData::Binary(&fk_value.to_be_bytes()))
+                .await
+                .unwrap();
+
+            assert!(
+                matches!(text_result, Shard::Direct(_)),
+                "FK value {} should route to a shard",
+                fk_value
+            );
+            assert_eq!(
+                text_result, binary_result,
+                "text and binary should match for FK value {}",
+                fk_value
+            );
+        }
+
+        // Look up non-existent FK values -> Shard::All
+        for non_existent in [99999i64, 123456, 999999] {
+            let text_result = lookup
+                .lookup(ShardingData::Text(&non_existent.to_string()))
+                .await
+                .unwrap();
+            let binary_result = lookup
+                .lookup(ShardingData::Binary(&non_existent.to_be_bytes()))
+                .await
+                .unwrap();
+
+            assert_eq!(text_result, Shard::All);
+            assert_eq!(text_result, binary_result);
+        }
+
+        // Clean up
+        cluster
+            .execute("DROP TABLE IF EXISTS pgdog.fk_orders, pgdog.fk_users")
+            .await
+            .unwrap();
+
+        cluster.shutdown();
+    }
+}
diff --git a/pgdog/src/backend/schema/join.rs b/pgdog/src/backend/schema/join.rs
new file mode 100644
index 000000000..b554df3a9
--- /dev/null
+++ b/pgdog/src/backend/schema/join.rs
@@ -0,0 +1,444 @@
+use std::collections::{HashSet, VecDeque};
+
+use serde::{Deserialize, Serialize};
+
+use crate::config::ShardedTable;
+use crate::{backend::ShardingSchema, frontend::router::parser::OwnedColumn};
+
+use super::Error;
+use super::Schema;
+use super::StatsRelation;
+
+/// A step in the join path from start table to the sharding key.
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct JoinStep {
+    /// Source column (includes table/schema info via the FK column).
+    pub from: OwnedColumn,
+    /// Target column (the referenced column in the target table).
+    pub to: OwnedColumn,
+}
+
+/// Result of constructing a join path to find a sharding key.
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Join {
+    /// The path of joins from start table to the table containing the sharding key.
+    pub path: Vec<JoinStep>,
+    /// The sharding key column in the final table.
+    pub sharding_column: OwnedColumn,
+    /// The SQL query to fetch the sharding key value using the FK value.
+    pub query: String,
+    /// The sharded table configuration.
+    pub sharded_table: ShardedTable,
+}
+
+impl Join {
+    /// Get the FK column name for FK lookup.
+    /// Returns the first FK column in the path that should be used for lookup.
+    pub fn fk_column(&self) -> Option<&str> {
+        self.path.first().map(|step| step.from.name.as_str())
+    }
+}
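To make the Join shape concrete before the BFS below: for the one-hop fixture used in this file's tests (orders.user_id references users.id, and users.user_id is the sharding key), the constructed value comes out as follows. Field values are restated from the test assertions; ShardedTable::default() stands in for the real config entry:

    let expected = Join {
        path: vec![JoinStep {
            from: OwnedColumn {
                name: "user_id".into(),
                table: Some("orders".into()),
                schema: Some("public".into()),
            },
            to: OwnedColumn {
                name: "id".into(),
                table: Some("users".into()),
                schema: Some("public".into()),
            },
        }],
        sharding_column: OwnedColumn {
            name: "user_id".into(),
            table: Some("users".into()),
            schema: Some("public".into()),
        },
        // One hop: the FK's target table already holds the sharding key.
        query: r#"SELECT "public"."users"."user_id" FROM "public"."users" WHERE "public"."users"."id" = $1"#.into(),
        sharded_table: ShardedTable::default(), // placeholder for the users/user_id entry
    };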
+
+impl Schema {
+    /// Construct a SELECT ... JOIN ... JOIN ...
+    /// query that fetches the value of the sharding key
+    /// from a table that has a foreign key relationship to the "start" table.
+    ///
+    /// The relationship can span multiple tables.
+    /// Uses BFS to find the shortest path.
+    pub fn construct_join(
+        &self,
+        start: &StatsRelation,
+        sharding: &ShardingSchema,
+    ) -> Result<Join, Error> {
+        let start_schema = start.schema();
+        let start_table = &start.name;
+
+        // Find primary key of start table
+        let start_pk = Self::find_primary_key_in_relation(start);
+
+        // Check if start table itself has the sharding key
+        if let Some((sharding_col, sharded_table)) =
+            Self::find_sharding_column_in_relation(start, sharding)
+        {
+            let query = self.build_direct_query(
+                start_schema,
+                start_table,
+                &sharding_col,
+                start_pk.as_deref(),
+            );
+            return Ok(Join {
+                path: vec![],
+                sharding_column: OwnedColumn {
+                    name: sharding_col,
+                    table: Some(start_table.clone()),
+                    schema: Some(start_schema.to_string()),
+                },
+                query,
+                sharded_table,
+            });
+        }
+
+        // BFS to find shortest path
+        let mut visited: HashSet<(&str, &str)> = HashSet::new();
+        let mut queue: VecDeque<(&str, &str, Vec<JoinStep>)> = VecDeque::new();
+
+        queue.push_back((start_schema, start_table, vec![]));
+        visited.insert((start_schema, start_table));
+
+        while let Some((current_schema, current_table, path)) = queue.pop_front() {
+            let relation = match self.inner.get(current_schema, current_table) {
+                Some(r) => r,
+                None => continue,
+            };
+
+            // Check each column for foreign keys
+            for column in relation.columns.values() {
+                for fk in &column.foreign_keys {
+                    let target_schema = &fk.schema;
+                    let target_table = &fk.table;
+                    let target_column = &fk.column;
+
+                    // Check if we've already visited this table
+                    let key = (target_schema.as_str(), target_table.as_str());
+                    if visited.contains(&key) {
+                        continue;
+                    }
+
+                    // Look up target relation
+                    let target_relation = match self.inner.get(target_schema, target_table) {
+                        Some(r) => r,
+                        None => continue,
+                    };
+
+                    // Build the new path
+                    let mut new_path = path.clone();
+                    new_path.push(JoinStep {
+                        from: OwnedColumn {
+                            name: column.column_name.clone(),
+                            table: Some(current_table.to_string()),
+                            schema: Some(current_schema.to_string()),
+                        },
+                        to: OwnedColumn {
+                            name: target_column.clone(),
+                            table: Some(target_table.clone()),
+                            schema: Some(target_schema.clone()),
+                        },
+                    });
+
+                    // Check if target table has the sharding key
+                    if let Some((sharding_col, sharded_table)) =
+                        Self::find_sharding_column_in_relation(target_relation, sharding)
+                    {
+                        // Build query that directly looks up sharding key via FK value
+                        let query = Self::build_fk_lookup_query(&new_path, &sharding_col);
+                        return Ok(Join {
+                            path: new_path,
+                            sharding_column: OwnedColumn {
+                                name: sharding_col,
+                                table: Some(target_table.clone()),
+                                schema: Some(target_schema.clone()),
+                            },
+                            query,
+                            sharded_table,
+                        });
+                    }
+
+                    // Add to queue
+                    visited.insert((target_relation.schema(), &target_relation.name));
+                    queue.push_back((target_relation.schema(), &target_relation.name, new_path));
+                }
+            }
+        }
+
+        Err(Error::NoForeignKeyPath {
+            table: format!("{}.{}", start_schema, start_table),
+        })
+    }
+
+    /// Find the primary key column of a relation.
+    fn find_primary_key_in_relation(relation: &StatsRelation) -> Option<String> {
+        relation
+            .columns
+            .values()
+            .find(|col| col.is_primary_key)
+            .map(|col| col.column_name.clone())
+    }
+
+    /// Check if a relation has a sharding key column.
+    /// Returns the column name and the matching ShardedTable config.
+    fn find_sharding_column_in_relation(
+        relation: &StatsRelation,
+        sharding: &ShardingSchema,
+    ) -> Option<(String, ShardedTable)> {
+        for sharded_table in sharding.tables.tables() {
+            // Match by schema if specified
+            if let Some(ref schema) = sharded_table.schema {
+                if schema != relation.schema() {
+                    continue;
+                }
+            }
+
+            // Match by table name if specified
+            if let Some(ref name) = sharded_table.name {
+                if name != &relation.name {
+                    continue;
+                }
+            }
+
+            // Check if the table has the sharding column
+            if relation.has_column(&sharded_table.column) {
+                return Some((sharded_table.column.clone(), sharded_table.clone()));
+            }
+        }
+
+        None
+    }
+
+    /// Build a simple SELECT query for direct access.
+    fn build_direct_query(
+        &self,
+        schema: &str,
+        table: &str,
+        column: &str,
+        primary_key: Option<&str>,
+    ) -> String {
+        let mut query = format!(
+            "SELECT \"{}\".\"{}\".\"{column}\" FROM \"{}\".\"{}\"",
+            schema, table, schema, table
+        );
+
+        if let Some(pk) = primary_key {
+            query.push_str(&format!(
+                " WHERE \"{}\".\"{}\".\"{}\" = $1",
+                schema, table, pk
+            ));
+        }
+
+        query
+    }
{ + let from_schema = step.from.schema.as_deref().unwrap_or("public"); + let from_table = step.from.table.as_deref().unwrap_or(""); + let to_schema = step.to.schema.as_deref().unwrap_or("public"); + let to_table = step.to.table.as_deref().unwrap_or(""); + + query.push_str(&format!( + " JOIN \"{}\".\"{}\" ON \"{}\".\"{}\".\"{}\" = \"{}\".\"{}\".\"{}\"", + to_schema, + to_table, + from_schema, + from_table, + step.from.name, + to_schema, + to_table, + step.to.name + )); + } + + // WHERE uses the first target's referenced column + query.push_str(&format!( + " WHERE \"{}\".\"{}\".\"{}\" = $1", + first_schema, first_table, first_step.to.name + )); + + query + } +} + +#[cfg(test)] +mod test { + use super::super::test_helpers::prelude::*; + + /// Build the standard test schema: + /// users (id PK, user_id - sharding key) + /// orders (id PK, user_id FK -> users.id) + /// order_items (id PK, order_id FK -> orders.id) + fn build_test_schema() -> ( + super::super::Schema, + crate::backend::pool::cluster::ShardingSchema, + ) { + let db_schema = schema() + .relation( + table("users") + .oid(1001) + .column(pk("id")) + .column(col("user_id")), + ) + .relation( + table("orders") + .oid(1002) + .column(pk("id")) + .column(fk("user_id", "users", "id")), + ) + .relation( + table("order_items") + .oid(1003) + .column(pk("id")) + .column(fk("order_id", "orders", "id")), + ) + .build(); + + let sharding_schema = sharding().sharded_table("users", "user_id").build(); + + (db_schema, sharding_schema) + } + + #[test] + fn test_construct_join_direct_table_has_sharding_key() { + let (db_schema, sharding_schema) = build_test_schema(); + + let relation = db_schema.inner.get("public", "users").unwrap(); + let join = db_schema + .construct_join(relation, &sharding_schema) + .unwrap(); + + assert!(join.path.is_empty()); + assert_eq!(join.sharding_column.name, "user_id"); + assert_eq!( + join.query, + r#"SELECT "public"."users"."user_id" FROM "public"."users" WHERE "public"."users"."id" = $1"# + ); + } + + #[test] + fn test_construct_join_one_hop() { + let (db_schema, sharding_schema) = build_test_schema(); + + let relation = db_schema.inner.get("public", "orders").unwrap(); + let join = db_schema + .construct_join(relation, &sharding_schema) + .unwrap(); + + assert_eq!(join.path.len(), 1); + assert_eq!(join.path[0].from.table, Some("orders".into())); + assert_eq!(join.path[0].from.name, "user_id"); + assert_eq!(join.path[0].to.table, Some("users".into())); + assert_eq!(join.path[0].to.name, "id"); + assert_eq!(join.sharding_column.name, "user_id"); + // Query directly looks up sharding key from parent table using FK value + assert_eq!( + join.query, + r#"SELECT "public"."users"."user_id" FROM "public"."users" WHERE "public"."users"."id" = $1"# + ); + } + + #[test] + fn test_construct_join_two_hops() { + let (db_schema, sharding_schema) = build_test_schema(); + + let relation = db_schema.inner.get("public", "order_items").unwrap(); + let join = db_schema + .construct_join(relation, &sharding_schema) + .unwrap(); + + assert_eq!(join.path.len(), 2); + assert_eq!(join.path[0].from.table, Some("order_items".into())); + assert_eq!(join.path[0].from.name, "order_id"); + assert_eq!(join.path[0].to.table, Some("orders".into())); + assert_eq!(join.path[1].from.table, Some("orders".into())); + assert_eq!(join.path[1].to.table, Some("users".into())); + assert_eq!(join.sharding_column.name, "user_id"); + // Query starts from first target (orders) and joins to find sharding key + assert_eq!( + join.query, + r#"SELECT 
"public"."users"."user_id" FROM "public"."orders" JOIN "public"."users" ON "public"."orders"."user_id" = "public"."users"."id" WHERE "public"."orders"."id" = $1"# + ); + } + + #[test] + fn test_construct_join_no_path() { + let sharding_schema = sharding().sharded_table("users", "user_id").build(); + + let db_schema = schema() + .relation(table("isolated").oid(9999).column(pk("id"))) + .build(); + + let relation = db_schema.inner.get("public", "isolated").unwrap(); + let result = db_schema.construct_join(relation, &sharding_schema); + assert!(result.is_err()); + } + + #[test] + fn test_compute_joins_and_get_join() { + let (mut db_schema, sharding_schema) = build_test_schema(); + + db_schema.computed_sharded_joins(&sharding_schema); + + // users has sharding key directly - no join stored but is_sharded = true + let users_relation = db_schema.inner.get("public", "users").unwrap(); + assert!(db_schema.get_sharded_join(users_relation).is_none()); + assert!(users_relation.is_sharded); + + // orders needs a join to get to users - is_sharded = true + let orders_relation = db_schema.inner.get("public", "orders").unwrap(); + let orders_join = db_schema.get_sharded_join(orders_relation); + assert!(orders_join.is_some()); + assert_eq!(orders_join.unwrap().path.len(), 1); + assert!(orders_relation.is_sharded); + + // order_items needs two joins to get to users - is_sharded = true + let order_items_relation = db_schema.inner.get("public", "order_items").unwrap(); + let order_items_join = db_schema.get_sharded_join(order_items_relation); + assert!(order_items_join.is_some()); + assert_eq!(order_items_join.unwrap().path.len(), 2); + assert!(order_items_relation.is_sharded); + + assert_eq!(db_schema.joins_count(), 2); + } + + #[test] + fn test_is_sharded_flag_not_set_for_isolated_table() { + let sharding_schema = sharding().sharded_table("users", "user_id").build(); + + let mut db_schema = schema() + .relation(table("isolated").oid(9999).column(pk("id"))) + .build(); + + db_schema.computed_sharded_joins(&sharding_schema); + + let relation = db_schema.inner.get("public", "isolated").unwrap(); + assert!(!relation.is_sharded); + } +} diff --git a/pgdog/src/backend/schema/mod.rs b/pgdog/src/backend/schema/mod.rs index 009bb63e0..8d1115613 100644 --- a/pgdog/src/backend/schema/mod.rs +++ b/pgdog/src/backend/schema/mod.rs @@ -1,28 +1,47 @@ //! Schema operations. pub mod columns; +pub mod fk_lookup; +pub mod join; pub mod relation; pub mod sync; +#[cfg(test)] +pub mod test_helpers; + +pub use fk_lookup::FkLookup; +use fnv::FnvHashMap; +pub use join::Join; pub use pgdog_stats::{ Relation as StatsRelation, Relations as StatsRelations, Schema as StatsSchema, SchemaInner, }; use serde::{Deserialize, Serialize}; +use std::hash::Hash; use std::ops::DerefMut; use std::{collections::HashMap, ops::Deref}; use tracing::debug; pub use relation::Relation; -use super::{pool::Request, Cluster, Error, Server}; +use super::{pool::Request, Cluster, Error, Server, ShardingSchema}; use crate::frontend::router::parser::Table; use crate::net::parameter::ParameterValue; static SETUP: &str = include_str!("setup.sql"); /// Load schema from database. -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Hash)] +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] pub struct Schema { inner: StatsSchema, + /// Precomputed joins for tables that don't have the sharding key directly. + /// Key is the relation OID. 
+ #[serde(skip)] + joins: FnvHashMap<i32, Join>, +} + +impl Hash for Schema { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.inner.hash(state); + } } impl Deref for Schema { @@ -66,6 +85,7 @@ impl Schema { Ok(Self { inner: StatsSchema::new(inner), + joins: FnvHashMap::default(), }) } @@ -91,6 +111,7 @@ impl Schema { search_path, relations: nested, }), + joins: FnvHashMap::default(), } } @@ -184,6 +205,21 @@ impl Schema { None } + /// Get this table's join to a sharded table. + pub fn table_sharded_join( + &self, + table: Table<'_>, + user: &str, + search_path: Option<&'_ ParameterValue>, + ) -> Option<&Join> { + let relation = self.table(table, user, search_path); + if let Some(relation) = relation { + self.get_sharded_join(relation) + } else { + None + } + } + fn resolve_search_path<'a>( &'a self, user: &'a str, @@ -223,6 +259,68 @@ impl Schema { pub fn search_path(&self) -> &[String] { &self.inner.search_path } + + /// Compute and store joins for all tables that don't have the sharding key directly. + /// Also sets the `is_sharded` flag on each relation based on whether it can + /// participate in sharded routing (either has sharding key directly or via FK path). + pub fn computed_sharded_joins(&mut self, sharding: &ShardingSchema) { + use std::collections::HashSet; + + self.joins.clear(); + + // Collect (oid, schema, name) to avoid borrow issues and enable O(1) lookup + let tables_info: Vec<(i32, String, String)> = self + .tables() + .iter() + .map(|r| (r.oid, r.schema().to_string(), r.name.clone())) + .collect(); + + // Track which tables can participate in sharding + let mut sharded_oids: HashSet<i32> = HashSet::new(); + + for (oid, schema_name, table_name) in tables_info { + // O(1) HashMap lookup instead of O(N) linear search + let relation = match self.inner.get(&schema_name, &table_name) { + Some(r) => r, + None => continue, + }; + + // Try to construct a join - if successful, table can participate in sharding + if let Ok(join) = self.construct_join(relation, sharding) { + sharded_oids.insert(oid); + + // Only store join path if table doesn't have sharding key directly + if !join.path.is_empty() { + self.joins.insert(oid, join); + } + } + } + + // Clone relations and set is_sharded flag + let mut new_relations = self.inner.relations.clone(); + for tables in new_relations.values_mut() { + for relation in tables.values_mut() { + relation.is_sharded = sharded_oids.contains(&relation.oid); + } + } + + // Rebuild inner with updated relations + self.inner = StatsSchema::new(SchemaInner { + search_path: self.inner.search_path.clone(), + relations: new_relations, + }); + } + + /// Look up a precomputed join for a relation by OID. + /// Returns None if the table has the sharding key directly or no path exists. + pub fn get_sharded_join(&self, relation: &StatsRelation) -> Option<&Join> { + self.joins.get(&relation.oid) + } + + /// Get the number of precomputed joins.
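+ /// For the users/orders/order_items schema used in the tests, this is 2: orders and order_items each store a join path, while users has the key directly and stores none.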
+ pub fn joins_count(&self) -> usize { + self.joins.len() + } } #[cfg(test)] diff --git a/pgdog/src/backend/schema/relation.rs b/pgdog/src/backend/schema/relation.rs index 903e6490c..1a301a2cd 100644 --- a/pgdog/src/backend/schema/relation.rs +++ b/pgdog/src/backend/schema/relation.rs @@ -60,6 +60,7 @@ impl From<DataRow> for Relation { description: value.get_text(6).unwrap_or_default(), oid: value.get::<i32>(7, Format::Text).unwrap_or_default(), columns: IndexMap::new(), + is_sharded: false, }, } } @@ -107,6 +108,7 @@ impl Relation { description: String::new(), oid: 0, columns: columns.into_iter().map(|(k, v)| (k, v.into())).collect(), + is_sharded: false, } .into() } diff --git a/pgdog/src/backend/schema/test_helpers.rs b/pgdog/src/backend/schema/test_helpers.rs new file mode 100644 index 000000000..6f8d34bcf --- /dev/null +++ b/pgdog/src/backend/schema/test_helpers.rs @@ -0,0 +1,314 @@ +//! Test helpers for creating schema structures. +//! +//! These helpers are used across multiple test modules to create +//! schemas with tables, columns, and foreign key relationships. + +use std::collections::HashMap; + +use indexmap::IndexMap; +use pgdog_stats::{Column as StatsColumn, ForeignKey, Relation as StatsRelation}; + +use super::relation::Relation; +use super::Schema; +use crate::backend::pool::cluster::ShardingSchema; +use crate::backend::replication::sharded_tables::ShardedTables; +use crate::config::ShardedTable; +use pgdog_config::SystemCatalogsBehavior; + +/// Builder for creating test columns. +#[derive(Default)] +pub struct ColumnBuilder { + schema: String, + table: String, + name: String, + ordinal: i32, + is_pk: bool, + data_type: String, + foreign_keys: Vec<ForeignKey>, +} + +impl ColumnBuilder { + pub fn new(name: &str) -> Self { + Self { + name: name.into(), + data_type: "bigint".into(), + ..Default::default() + } + } + + pub fn schema(mut self, schema: &str) -> Self { + self.schema = schema.into(); + self + } + + pub fn table(mut self, table: &str) -> Self { + self.table = table.into(); + self + } + + pub fn ordinal(mut self, ordinal: i32) -> Self { + self.ordinal = ordinal; + self + } + + pub fn primary_key(mut self) -> Self { + self.is_pk = true; + self + } + + pub fn data_type(mut self, data_type: &str) -> Self { + self.data_type = data_type.into(); + self + } + + pub fn foreign_key(mut self, ref_schema: &str, ref_table: &str, ref_column: &str) -> Self { + self.foreign_keys.push(ForeignKey { + schema: ref_schema.into(), + table: ref_table.into(), + column: ref_column.into(), + ..Default::default() + }); + self + } + + pub fn build(self) -> StatsColumn { + StatsColumn { + table_catalog: "test".into(), + table_schema: self.schema, + table_name: self.table, + column_name: self.name, + column_default: String::new(), + is_nullable: !self.is_pk, + data_type: self.data_type, + ordinal_position: self.ordinal, + is_primary_key: self.is_pk, + foreign_keys: self.foreign_keys, + } + } +} + +/// Builder for creating test relations (tables).
+pub struct RelationBuilder { + schema: String, + name: String, + oid: i32, + columns: Vec<ColumnBuilder>, +} + +impl RelationBuilder { + pub fn new(name: &str) -> Self { + Self { + schema: "public".into(), + name: name.into(), + oid: 0, + columns: vec![], + } + } + + pub fn schema(mut self, schema: &str) -> Self { + self.schema = schema.into(); + self + } + + pub fn oid(mut self, oid: i32) -> Self { + self.oid = oid; + self + } + + pub fn column(mut self, col: ColumnBuilder) -> Self { + let ordinal = self.columns.len() as i32 + 1; + self.columns.push( + col.schema(&self.schema) + .table(&self.name) + .ordinal(ordinal) + .build(), + ); + self + } + + pub fn build(self) -> Relation { + let cols: IndexMap<String, StatsColumn> = self + .columns + .into_iter() + .map(|c| (c.column_name.clone(), c)) + .collect(); + StatsRelation { + schema: self.schema, + name: self.name, + type_: "table".into(), + owner: String::new(), + persistence: String::new(), + access_method: String::new(), + description: String::new(), + oid: self.oid, + columns: cols, + is_sharded: false, + } + .into() + } +} + +/// Builder for creating test schemas with multiple tables. +pub struct SchemaBuilder { + search_path: Vec<String>, + relations: HashMap<(String, String), Relation>, +} + +impl Default for SchemaBuilder { + fn default() -> Self { + Self::new() + } +} + +impl SchemaBuilder { + pub fn new() -> Self { + Self { + search_path: vec!["public".into()], + relations: HashMap::new(), + } + } + + pub fn search_path(mut self, path: Vec<&str>) -> Self { + self.search_path = path.into_iter().map(String::from).collect(); + self + } + + pub fn relation(mut self, rel: RelationBuilder) -> Self { + let relation = rel.build(); + let key = (relation.schema.clone(), relation.name.clone()); + self.relations.insert(key, relation); + self + } + + pub fn build(self) -> Schema { + Schema::from_parts(self.search_path, self.relations) + } +} + +/// Builder for creating sharding configuration. +pub struct ShardingBuilder { + shards: usize, + tables: Vec<ShardedTable>, +} + +impl Default for ShardingBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ShardingBuilder { + pub fn new() -> Self { + Self { + shards: 2, + tables: vec![], + } + } + + pub fn shards(mut self, n: usize) -> Self { + self.shards = n; + self + } + + pub fn sharded_table(mut self, table: &str, column: &str) -> Self { + self.tables.push(ShardedTable { + database: "test".into(), + name: Some(table.into()), + column: column.into(), + ..Default::default() + }); + self + } + + pub fn sharded_column(mut self, column: &str) -> Self { + self.tables.push(ShardedTable { + database: "test".into(), + column: column.into(), + ..Default::default() + }); + self + } + + pub fn database(mut self, database: &str) -> Self { + if let Some(table) = self.tables.last_mut() { + table.database = database.into(); + } + self + } + + pub fn build(self) -> ShardingSchema { + ShardingSchema { + shards: self.shards, + tables: ShardedTables::new( + self.tables, + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + } + } +} + +/// Convenience functions for common test patterns. +pub mod prelude { + pub use super::{ColumnBuilder, RelationBuilder, SchemaBuilder, ShardingBuilder}; + + /// Create a simple column. + pub fn col(name: &str) -> ColumnBuilder { + ColumnBuilder::new(name) + } + + /// Create a primary key column. + pub fn pk(name: &str) -> ColumnBuilder { + ColumnBuilder::new(name).primary_key() + } + + /// Create a foreign key column.
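+ /// Assumes the referenced table is in the `public` schema; e.g. fk("user_id", "users", "id") models orders.user_id -> users.id, as in the builder test below.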
+ pub fn fk(name: &str, ref_table: &str, ref_column: &str) -> ColumnBuilder { + ColumnBuilder::new(name).foreign_key("public", ref_table, ref_column) + } + + /// Create a table builder. + pub fn table(name: &str) -> RelationBuilder { + RelationBuilder::new(name) + } + + /// Create a schema builder. + pub fn schema() -> SchemaBuilder { + SchemaBuilder::new() + } + + /// Create a sharding config builder. + pub fn sharding() -> ShardingBuilder { + ShardingBuilder::new() + } +} + +#[cfg(test)] +mod tests { + use super::prelude::*; + + #[test] + fn test_builders() { + let db_schema = schema() + .relation( + table("users") + .oid(1001) + .column(pk("id")) + .column(col("user_id")), + ) + .relation( + table("orders") + .oid(1002) + .column(pk("id")) + .column(fk("user_id", "users", "id")), + ) + .build(); + + let sharding = sharding().sharded_table("users", "user_id").build(); + + assert_eq!(db_schema.tables().len(), 2); + assert_eq!(sharding.shards, 2); + } +} diff --git a/pgdog/src/frontend/client/query_engine/test/copy.rs b/pgdog/src/frontend/client/query_engine/test/copy.rs new file mode 100644 index 000000000..99c89886d --- /dev/null +++ b/pgdog/src/frontend/client/query_engine/test/copy.rs @@ -0,0 +1,281 @@ +//! End-to-end tests for COPY with FK lookup sharding. +//! +//! Tests that the query engine correctly routes COPY rows via FK relationships +//! when the target table doesn't have the sharding key directly. + +use std::ops::Deref; + +use crate::{ + backend::databases::reload_from_existing, + config::{config, load_test_sharded, set, DataType, Hasher, ShardedTable}, + expect_message, + frontend::client::test::TestClient, + net::{CommandComplete, Parameters, Protocol, Query, ReadyForQuery}, +}; + +/// Load test config with FK sharding tables. +fn load_test_fk_sharded() { + load_test_sharded(); + + let mut config = config().deref().clone(); + // Add the FK parent table as a sharded table + config.config.sharded_tables.push(ShardedTable { + database: "pgdog".into(), + name: Some("copy_fk_users".into()), + column: "customer_id".into(), + primary: true, + data_type: DataType::Bigint, + hasher: Hasher::Postgres, + ..Default::default() + }); + set(config).unwrap(); + reload_from_existing().unwrap(); +} + +/// Test COPY with multi-hop FK lookup through the query engine. 
+/// +/// This test uses 3 tables to verify FK traversal across multiple hops: +/// copy_fk_users (id, customer_id) - has sharding key +/// copy_fk_orders (id, user_id FK -> users) - 1 hop from sharding key +/// copy_fk_order_items (id, order_id FK -> orders) - 2 hops from sharding key +/// +/// When we COPY into order_items, pgdog must traverse: +/// order_items -> orders -> users to find customer_id +#[tokio::test] +async fn test_copy_fk_lookup_end_to_end() { + load_test_fk_sharded(); + let mut client = TestClient::new_sharded(Parameters::default()).await; + + // Drop and create FK tables (3 levels deep) + client + .send(Query::new( + "DROP TABLE IF EXISTS copy_fk_order_items, copy_fk_orders, copy_fk_users", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + + // Level 1: users with sharding key + client + .send(Query::new( + "CREATE TABLE copy_fk_users ( + id BIGINT PRIMARY KEY, + customer_id BIGINT NOT NULL + )", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + + // Level 2: orders with FK to users + client + .send(Query::new( + "CREATE TABLE copy_fk_orders ( + id BIGINT PRIMARY KEY, + user_id BIGINT REFERENCES copy_fk_users(id) + )", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + + // Level 3: order_items with FK to orders (2 hops from sharding key) + client + .send(Query::new( + "CREATE TABLE copy_fk_order_items ( + id BIGINT PRIMARY KEY, + order_id BIGINT REFERENCES copy_fk_orders(id) + )", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + + // Insert users with known customer_id values + for i in 1i64..=10 { + client + .send(Query::new(&format!( + "INSERT INTO copy_fk_users (id, customer_id) VALUES ({}, {})", + i, + i * 100 + ))) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + } + + // Insert orders referencing users + for i in 1i64..=10 { + client + .send(Query::new(&format!( + "INSERT INTO copy_fk_orders (id, user_id) VALUES ({}, {})", + i * 10, + i + ))) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + } + + // Send COPY command for order_items table (2 hops from sharding key) + client + .send(Query::new( + "COPY copy_fk_order_items (id, order_id) FROM STDIN", + )) + .await; + client.try_process().await.unwrap(); + + // Expect CopyInResponse (code 'G') + let copy_in = client.read().await; + assert_eq!(copy_in.code(), 'G', "expected CopyInResponse"); + + // Send COPY data - 10 order_items referencing orders + use crate::net::messages::CopyData; + + let copy_rows: String = (1i64..=10) + .map(|i| format!("{}\t{}\n", i * 100, i * 10)) // item_id, order_id (FK) + .collect(); + + let copy_data = CopyData::new(copy_rows.as_bytes()); + client.send(copy_data).await; + + // Send CopyDone + use crate::net::CopyDone; + client.send(CopyDone).await; + + client.try_process().await.unwrap(); + + // Expect CommandComplete and ReadyForQuery + expect_message!(client.read().await, CommandComplete); + let rfq = expect_message!(client.read().await, ReadyForQuery); + assert_eq!(rfq.status, 'I'); + + // Verify all order_items were inserted + client + .send(Query::new("SELECT COUNT(*) FROM copy_fk_order_items")) + .await; + client.try_process().await.unwrap(); + + let messages = client.read_until('Z').await.unwrap(); + let data_row = messages + .iter() + .find(|m| m.code() == 'D') + .expect("should have DataRow"); + let data_row = 
crate::net::DataRow::try_from(data_row.clone()).unwrap(); + let count = data_row.get_int(0, true).expect("should have count"); + assert_eq!(count, 10, "expected 10 order_items"); + + // Verify data integrity + client + .send(Query::new( + "SELECT id, order_id FROM copy_fk_order_items ORDER BY id", + )) + .await; + client.try_process().await.unwrap(); + + let messages = client.read_until('Z').await.unwrap(); + let data_rows: Vec<_> = messages + .iter() + .filter(|m| m.code() == 'D') + .map(|m| crate::net::DataRow::try_from(m.clone()).unwrap()) + .collect(); + + assert_eq!(data_rows.len(), 10, "expected 10 data rows"); + for (i, row) in data_rows.iter().enumerate() { + let item_id = row.get_int(0, true).expect("item_id"); + let order_id = row.get_int(1, true).expect("order_id"); + assert_eq!(item_id, ((i + 1) * 100) as i64, "item_id mismatch"); + assert_eq!(order_id, ((i + 1) * 10) as i64, "order_id mismatch"); + } + + // Clean up + client + .send(Query::new( + "DROP TABLE IF EXISTS copy_fk_order_items, copy_fk_orders, copy_fk_users", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); +} + +/// Test COPY with direct sharding key (no FK lookup needed). +#[tokio::test] +async fn test_copy_direct_sharding_key_end_to_end() { + load_test_fk_sharded(); + let mut client = TestClient::new_sharded(Parameters::default()).await; + + // Drop and create table with sharding key directly + client + .send(Query::new( + "DROP TABLE IF EXISTS copy_fk_orders, copy_fk_users", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + + client + .send(Query::new( + "CREATE TABLE copy_fk_users ( + id BIGINT PRIMARY KEY, + customer_id BIGINT NOT NULL + )", + )) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); + + // Send COPY command for users table (has sharding key directly) + client + .send(Query::new( + "COPY copy_fk_users (id, customer_id) FROM STDIN", + )) + .await; + client.try_process().await.unwrap(); + + // Expect CopyInResponse (code 'G') + let copy_in = client.read().await; + assert_eq!(copy_in.code(), 'G', "expected CopyInResponse"); + + // Send COPY data - 10 users with customer_id (sharding key) + use crate::net::messages::CopyData; + + let copy_rows: String = (1i64..=10) + .map(|i| format!("{}\t{}\n", i, i * 100)) // id, customer_id + .collect(); + + let copy_data = CopyData::new(copy_rows.as_bytes()); + client.send(copy_data).await; + + // Send CopyDone + use crate::net::CopyDone; + client.send(CopyDone).await; + + client.try_process().await.unwrap(); + + // Expect CommandComplete and ReadyForQuery + expect_message!(client.read().await, CommandComplete); + let rfq = expect_message!(client.read().await, ReadyForQuery); + assert_eq!(rfq.status, 'I'); + + // Verify all users were inserted + client + .send(Query::new("SELECT COUNT(*) FROM copy_fk_users")) + .await; + client.try_process().await.unwrap(); + + let messages = client.read_until('Z').await.unwrap(); + let data_row_msg = messages + .into_iter() + .find(|m| m.code() == 'D') + .expect("should have DataRow"); + let data_row = crate::net::DataRow::try_from(data_row_msg).unwrap(); + let count = data_row.get_int(0, true).expect("should have count"); + assert_eq!(count, 10, "expected 10 users"); + + // Clean up + client + .send(Query::new("DROP TABLE IF EXISTS copy_fk_users")) + .await; + client.try_process().await.unwrap(); + client.read_until('Z').await.unwrap(); +} diff --git 
a/pgdog/src/frontend/client/query_engine/test/mod.rs b/pgdog/src/frontend/client/query_engine/test/mod.rs index de93b592b..e70c2a438 100644 --- a/pgdog/src/frontend/client/query_engine/test/mod.rs +++ b/pgdog/src/frontend/client/query_engine/test/mod.rs @@ -7,6 +7,7 @@ use crate::{ net::{Parameters, Stream}, }; +mod copy; mod omni; pub mod prelude; mod rewrite_extended; diff --git a/pgdog/src/frontend/router/mod.rs b/pgdog/src/frontend/router/mod.rs index 2e37ba3d3..6e03072ed 100644 --- a/pgdog/src/frontend/router/mod.rs +++ b/pgdog/src/frontend/router/mod.rs @@ -73,9 +73,9 @@ impl Router { } /// Parse CopyData messages and shard them. - pub fn copy_data(&mut self, buffer: &ClientRequest) -> Result<Vec<CopyRow>, Error> { + pub async fn copy_data(&mut self, buffer: &ClientRequest) -> Result<Vec<CopyRow>, Error> { match self.latest_command { - Command::Copy(ref mut copy) => Ok(copy.shard(&buffer.copy_data()?)?), + Command::Copy(ref mut copy) => Ok(copy.shard(&buffer.copy_data()?).await?), _ => Ok(buffer .copy_data()? .into_iter() diff --git a/pgdog/src/frontend/router/parser/column.rs b/pgdog/src/frontend/router/parser/column.rs index b9f13e1da..c7f20cfee 100644 --- a/pgdog/src/frontend/router/parser/column.rs +++ b/pgdog/src/frontend/router/parser/column.rs @@ -21,7 +21,7 @@ pub struct Column<'a> { } /// Owned version of Column that owns its string data. -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize)] pub struct OwnedColumn { /// Column name. pub name: String, diff --git a/pgdog/src/frontend/router/parser/copy.rs b/pgdog/src/frontend/router/parser/copy.rs index 4b413f144..ee3730201 100644 --- a/pgdog/src/frontend/router/parser/copy.rs +++ b/pgdog/src/frontend/router/parser/copy.rs @@ -1,19 +1,25 @@ //! Parse COPY statement. +use std::sync::Arc; + use pg_query::{protobuf::CopyStmt, NodeEnum}; +use tokio::sync::Mutex; use crate::{ - backend::{Cluster, ShardingSchema}, + backend::{schema::FkLookup, Cluster, ShardingSchema}, config::ShardedTable, frontend::router::{ parser::Shard, - sharding::{ContextBuilder, Tables}, + sharding::{ContextBuilder, Data as ShardingData}, CopyRow, }, net::messages::{CopyData, ToBytes}, }; -use super::{binary::Data, BinaryStream, Column, CsvStream, Error, Table}; +use super::{ + binary::Data as BinaryData, BinaryStream, CsvStream, Error, SchemaLookupContext, + StatementParser, Table, +}; /// Copy information parsed from a COPY statement. #[derive(Debug, Clone)] @@ -70,6 +76,14 @@ pub struct CopyParser { sharded_table: Option<ShardedTable>, /// The sharding column is in this position in each row. sharded_column: usize, + /// The primary key column position in each row. + primary_key_column: Option<usize>, + /// The FK column position for FK lookup (references parent table). + fk_column: Option<usize>, + /// Whether this COPY targets a sharded table. + is_sharded: bool, + /// FK lookup for fetching sharding key when table doesn't have it directly. + fk_lookup: Option<Arc<Mutex<FkLookup>>>, /// Schema shard. schema_shard: Option<Shard>, /// String representing NULL values in text/CSV format.
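+ /// Defaults to "\\N" (the Postgres text-format default, set in the Default impl below); a COPY option like NULL 'NULL' overrides it, as exercised in test_copy_csv_custom_null.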
@@ -87,6 +101,10 @@ impl Default for CopyParser { sharding_schema: ShardingSchema::default(), sharded_table: None, sharded_column: 0, + primary_key_column: None, + fk_column: None, + is_sharded: false, + fk_lookup: None, schema_shard: None, null_string: "\\N".to_owned(), } @@ -104,30 +122,68 @@ impl CopyParser { let mut format = CopyFormat::Text; let mut null_string = "\\N".to_owned(); - if let Some(ref rel) = stmt.relation { - let mut columns = vec![]; - - for column in &stmt.attlist { - if let Ok(column) = Column::from_string(column) { - columns.push(column); - } - } + let sharding_schema = cluster.sharding_schema(); + let db_schema = cluster.schema(); + if let Some(ref rel) = stmt.relation { let table = Table::from(rel); // The CopyParser is used for replicating // data during data-sync. This will ensure all rows // are sent to the right schema-based shard. - if let Some(schema) = cluster.sharding_schema().schemas.get(table.schema()) { + if let Some(schema) = sharding_schema.schemas.get(table.schema()) { parser.schema_shard = Some(schema.shard().into()); } - if let Some(key) = Tables::new(&cluster.sharding_schema()).key(table, &columns) { - parser.sharded_table = Some(key.table.clone()); - parser.sharded_column = key.position; + let schema_lookup = SchemaLookupContext { + db_schema: &db_schema, + user: "", + search_path: None, + }; + + let mut statement_parser = StatementParser::from_copy(stmt, &sharding_schema) + .with_schema_lookup(schema_lookup); + + parser.is_sharded = statement_parser.is_sharded(&db_schema, "", None); + + if let Some(sharding) = statement_parser.copy_sharding_key() { + if let (Some(position), Some(sharded_table)) = + (sharding.key_position, sharding.key_table) + { + parser.sharded_table = Some(sharded_table.clone()); + parser.sharded_column = position; + } + parser.primary_key_column = sharding.primary_key_position; + } + + // If table is sharded but doesn't have the sharding key directly, + // create an FK lookup to fetch sharding key via FK relationships. + if parser.is_sharded && parser.sharded_table.is_none() { + if let Some(join) = db_schema.table_sharded_join(table, "", None) { + // Find the FK column position in COPY columns + if let Some(fk_col_name) = join.fk_column() { + // Get column names from COPY statement + let copy_columns: Vec<String> = stmt + .attlist + .iter() + .filter_map(|node| { + if let Some(NodeEnum::String(s)) = &node.node { + Some(s.sval.clone()) + } else { + None + } + }) + .collect(); + + if let Some(pos) = copy_columns.iter().position(|c| c == fk_col_name) { + parser.fk_column = Some(pos); + parser.fk_lookup = Self::build_fk_lookup(&join, cluster); + } + } + } } - parser.columns = columns.len(); + parser.columns = stmt.attlist.len(); for option in &stmt.options { if let Some(NodeEnum::DefElem(ref elem)) = option.node { @@ -200,9 +256,29 @@ impl CopyParser { self.delimiter.unwrap_or('\t') } + /// Build an FK lookup from a join and cluster. + fn build_fk_lookup( + join: &crate::backend::schema::Join, + cluster: &Cluster, + ) -> Option<Arc<Mutex<FkLookup>>> { + Some(Arc::new(Mutex::new(FkLookup::new( + join.clone(), + cluster.clone(), + )))) + } + + /// Override the cluster used for FK lookups (e.g., during replication). + /// Rebuilds the FK lookup with the new cluster.
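+ /// Note: reading the current join goes through tokio's `Mutex::blocking_lock()`, which panics if called from within an async context, so this is only safe from synchronous code. A hypothetical call site during replication setup: `copy_parser.set_fk_lookup_cluster(&destination_cluster);`.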
+ pub fn set_fk_lookup_cluster(&mut self, cluster: &Cluster) { + if let Some(ref fk_lookup) = self.fk_lookup { + let join = fk_lookup.blocking_lock().join().clone(); + self.fk_lookup = Self::build_fk_lookup(&join, cluster); + } + } + /// Split CopyData (F) messages into multiple CopyData (F) messages /// with shard numbers. - pub fn shard(&mut self, data: &[CopyData]) -> Result, Error> { + pub async fn shard(&mut self, data: &[CopyData]) -> Result, Error> { let mut rows = vec![]; for row in data { @@ -250,6 +326,25 @@ impl CopyParser { } } else if let Some(schema_shard) = self.schema_shard.clone() { schema_shard + } else if let Some(fk_lookup) = self.fk_lookup.as_ref() { + if let Some(fk_col) = self.fk_column { + let fk_value = record.get(fk_col).ok_or(Error::NoShardingColumn)?; + if fk_value == self.null_string { + Shard::All + } else { + match fk_lookup + .lock() + .await + .lookup(ShardingData::Text(fk_value)) + .await + { + Ok(shard) => shard, + Err(_) => Shard::All, + } + } + } else { + Shard::All + } } else { Shard::All }; @@ -282,7 +377,7 @@ impl CopyParser { let key = tuple .get(self.sharded_column) .ok_or(Error::NoShardingColumn)?; - if let Data::Column(key) = key { + if let BinaryData::Column(key) = key { let ctx = ContextBuilder::new(table) .data(&key[..]) .shards(self.sharding_schema.shards) @@ -294,6 +389,25 @@ impl CopyParser { } } else if let Some(schema_shard) = self.schema_shard.clone() { schema_shard + } else if let Some(fk_lookup) = self.fk_lookup.as_ref() { + if let Some(fk_col) = self.fk_column { + let fk_value = tuple.get(fk_col).ok_or(Error::NoShardingColumn)?; + if let BinaryData::Column(fk_bytes) = fk_value { + match fk_lookup + .lock() + .await + .lookup(ShardingData::Binary(&fk_bytes[..])) + .await + { + Ok(shard) => shard, + Err(_) => Shard::All, + } + } else { + Shard::All + } + } else { + Shard::All + } } else { Shard::All }; @@ -316,8 +430,8 @@ mod test { use super::*; - #[test] - fn test_copy_text() { + #[tokio::test] + async fn test_copy_text() { let copy = "COPY sharded (id, value) FROM STDIN"; let stmt = parse(copy).unwrap(); let stmt = stmt.protobuf.stmts.first().unwrap(); @@ -333,13 +447,13 @@ mod test { let one = CopyData::new("5\thello world\n".as_bytes()); let two = CopyData::new("10\thowdy mate\n".as_bytes()); - let sharded = copy.shard(&[one, two]).unwrap(); + let sharded = copy.shard(&[one, two]).await.unwrap(); assert_eq!(sharded[0].message().data(), b"5\thello world\n"); assert_eq!(sharded[1].message().data(), b"10\thowdy mate\n"); } - #[test] - fn test_copy_csv() { + #[tokio::test] + async fn test_copy_csv() { let copy = "COPY sharded (id, value) FROM STDIN CSV HEADER"; let stmt = parse(copy).unwrap(); let stmt = stmt.protobuf.stmts.first().unwrap(); @@ -357,7 +471,7 @@ mod test { let header = CopyData::new("id,value\n".as_bytes()); let one = CopyData::new("5,hello world\n".as_bytes()); let two = CopyData::new("10,howdy mate\n".as_bytes()); - let sharded = copy.shard(&[header, one, two]).unwrap(); + let sharded = copy.shard(&[header, one, two]).await.unwrap(); assert_eq!(sharded[0].message().data(), b"\"id\",\"value\"\n"); assert_eq!(sharded[1].message().data(), b"\"5\",\"hello world\"\n"); @@ -367,16 +481,16 @@ mod test { let partial_two = CopyData::new("\n1,2".as_bytes()); let partial_three = CopyData::new("\n".as_bytes()); - let sharded = copy.shard(&[partial_one]).unwrap(); + let sharded = copy.shard(&[partial_one]).await.unwrap(); assert!(sharded.is_empty()); - let sharded = copy.shard(&[partial_two]).unwrap(); + let sharded = 
copy.shard(&[partial_two]).await.unwrap(); assert_eq!(sharded[0].message().data(), b"\"11\",\"howdy partner\"\n"); - let sharded = copy.shard(&[partial_three]).unwrap(); + let sharded = copy.shard(&[partial_three]).await.unwrap(); assert_eq!(sharded[0].message().data(), b"\"1\",\"2\"\n"); } - #[test] - fn test_copy_csv_stream() { + #[tokio::test] + async fn test_copy_csv_stream() { let copy_data = CopyData::new(b"id,value\n1,test\n6,test6\n"); let copy = "COPY sharded (id, value) FROM STDIN CSV HEADER"; @@ -389,7 +503,7 @@ let mut copy = CopyParser::new(&copy, &Cluster::new_test(&config())).unwrap(); - let rows = copy.shard(&[copy_data]).unwrap(); + let rows = copy.shard(&[copy_data]).await.unwrap(); assert_eq!(rows.len(), 3); assert_eq!(rows[0].message(), CopyData::new(b"\"id\",\"value\"\n")); assert_eq!(rows[0].shard(), &Shard::All); @@ -399,8 +513,8 @@ assert_eq!(rows[2].shard(), &Shard::Direct(1)); } - #[test] - fn test_copy_csv_custom_null() { + #[tokio::test] + async fn test_copy_csv_custom_null() { let copy = "COPY sharded (id, value) FROM STDIN CSV NULL 'NULL'"; let stmt = parse(copy).unwrap(); let stmt = stmt.protobuf.stmts.first().unwrap(); @@ -415,7 +529,7 @@ assert!(!copy.headers); let data = CopyData::new("5,hello\n10,NULL\n15,world\n".as_bytes()); - let sharded = copy.shard(&[data]).unwrap(); + let sharded = copy.shard(&[data]).await.unwrap(); assert_eq!(sharded.len(), 3); assert_eq!(sharded[0].message().data(), b"\"5\",\"hello\"\n"); @@ -423,8 +537,8 @@ assert_eq!(sharded[2].message().data(), b"\"15\",\"world\"\n"); } - #[test] - fn test_copy_text_pg_dump_end_marker() { + #[tokio::test] + async fn test_copy_text_pg_dump_end_marker() { // pg_dump generates text format COPY with `\.` as end-of-copy marker. // This marker should be sent to all shards without extracting a sharding key. let copy = "COPY sharded (id, value) FROM STDIN"; @@ -441,7 +555,7 @@ let two = CopyData::new("6\tBob\n".as_bytes()); let end_marker = CopyData::new("\\.\n".as_bytes()); - let sharded = copy.shard(&[one, two, end_marker]).unwrap(); + let sharded = copy.shard(&[one, two, end_marker]).await.unwrap(); assert_eq!(sharded.len(), 3); assert_eq!(sharded[0].message().data(), b"1\tAlice\n"); assert_eq!(sharded[0].shard(), &Shard::Direct(0)); @@ -451,8 +565,8 @@ assert_eq!(sharded[2].shard(), &Shard::All); } - #[test] - fn test_copy_text_null_sharding_key() { + #[tokio::test] + async fn test_copy_text_null_sharding_key() { // pg_dump text format uses `\N` to represent NULL values. // When the sharding key is NULL, route to all shards. // When a non-sharding column is NULL, route normally based on the key.
@@ -471,7 +585,7 @@ mod test { let three = CopyData::new("11\tCharlie\n".as_bytes()); let four = CopyData::new("6\t\\N\n".as_bytes()); - let sharded = copy.shard(&[one, two, three, four]).unwrap(); + let sharded = copy.shard(&[one, two, three, four]).await.unwrap(); assert_eq!(sharded.len(), 4); assert_eq!(sharded[0].message().data(), b"1\tAlice\n"); assert_eq!(sharded[0].shard(), &Shard::Direct(0)); @@ -483,8 +597,8 @@ mod test { assert_eq!(sharded[3].shard(), &Shard::Direct(1)); } - #[test] - fn test_copy_text_composite_type_sharded() { + #[tokio::test] + async fn test_copy_text_composite_type_sharded() { // Test the same composite type but with sharding enabled (using the sharded table from config) let copy = "COPY sharded (id, value) FROM STDIN"; let stmt = parse(copy).unwrap(); @@ -498,7 +612,7 @@ mod test { // Row where the value contains a composite type with commas and quotes let row = CopyData::new(b"1\t(,Annapolis,Maryland,\"United States\",)\n"); - let sharded = copy.shard(&[row]).unwrap(); + let sharded = copy.shard(&[row]).await.unwrap(); assert_eq!(sharded.len(), 1); @@ -510,8 +624,8 @@ mod test { ); } - #[test] - fn test_copy_explicit_text_format() { + #[tokio::test] + async fn test_copy_explicit_text_format() { // Test with explicit FORMAT text (like during resharding) let copy = r#"COPY "public"."entity_values" ("id", "value_location") FROM STDIN WITH (FORMAT text)"#; let stmt = parse(copy).unwrap(); @@ -532,7 +646,7 @@ mod test { // Row with composite type let row = CopyData::new(b"1\t(,Annapolis,Maryland,\"United States\",)\n"); - let sharded = copy.shard(&[row]).unwrap(); + let sharded = copy.shard(&[row]).await.unwrap(); assert_eq!(sharded.len(), 1); assert_eq!( @@ -542,8 +656,8 @@ mod test { ); } - #[test] - fn test_copy_binary() { + #[tokio::test] + async fn test_copy_binary() { let copy = "COPY sharded (id, value) FROM STDIN (FORMAT 'binary')"; let stmt = parse(copy).unwrap(); let stmt = stmt.protobuf.stmts.first().unwrap(); @@ -570,10 +684,383 @@ mod test { data.extend(b"yes"); data.extend((-1_i16).to_be_bytes()); let header = CopyData::new(data.as_slice()); - let sharded = copy.shard(&[header]).unwrap(); + let sharded = copy.shard(&[header]).await.unwrap(); assert_eq!(sharded.len(), 3); assert_eq!(sharded[0].message().data(), &data[..19]); // Header is 19 bytes long. 
assert_eq!(sharded[1].message().data().len(), 2 + 4 + 8 + 4 + 3); assert_eq!(sharded[2].message().data(), (-1_i16).to_be_bytes()); } + + #[test] + fn test_copy_fk_lookup_setup() { + use crate::backend::schema::test_helpers::prelude::*; + + // Create schema with FK relationships: + // users (id PK, customer_id - sharding key) + // orders (id PK, user_id FK -> users.id) + // order_items (id PK, order_id FK -> orders.id) + let mut db_schema = schema() + .relation( + table("users") + .oid(1001) + .column(pk("id")) + .column(col("customer_id")), + ) + .relation( + table("orders") + .oid(1002) + .column(pk("id")) + .column(fk("user_id", "users", "id")), + ) + .relation( + table("order_items") + .oid(1003) + .column(pk("id")) + .column(fk("order_id", "orders", "id")), + ) + .build(); + + let sharding_schema = sharding().sharded_table("users", "customer_id").build(); + + // Compute joins so tables are marked as sharded and join paths are stored + db_schema.computed_sharded_joins(&sharding_schema); + + // Verify users has sharding key directly (no join stored, but is_sharded = true) + let users_table = Table { + name: "users", + schema: Some("public"), + alias: None, + }; + let users = db_schema.table(users_table, "", None).unwrap(); + assert!(users.is_sharded); + assert!(db_schema.get_sharded_join(users).is_none()); + + // Verify orders has a join to users (is_sharded = true, join stored) + let orders_table = Table { + name: "orders", + schema: Some("public"), + alias: None, + }; + let orders = db_schema.table(orders_table, "", None).unwrap(); + assert!(orders.is_sharded); + let orders_join = db_schema.get_sharded_join(orders); + assert!(orders_join.is_some()); + let orders_join = orders_join.unwrap(); + assert_eq!(orders_join.path.len(), 1); + assert_eq!(orders_join.fk_column(), Some("user_id")); + // Query should directly look up sharding key from users + assert!( + orders_join.query.contains("users"), + "query should reference users table" + ); + assert!( + orders_join.query.contains("customer_id"), + "query should select customer_id" + ); + assert!( + !orders_join.query.contains("orders"), + "query should NOT reference orders table" + ); + + // Verify order_items has a 2-hop join to users + let order_items_table = Table { + name: "order_items", + schema: Some("public"), + alias: None, + }; + let order_items = db_schema.table(order_items_table, "", None).unwrap(); + assert!(order_items.is_sharded); + let order_items_join = db_schema.get_sharded_join(order_items); + assert!(order_items_join.is_some()); + let order_items_join = order_items_join.unwrap(); + assert_eq!(order_items_join.path.len(), 2); + assert_eq!(order_items_join.fk_column(), Some("order_id")); + // Query should start from orders (first hop target), not order_items + assert!( + order_items_join.query.contains("orders"), + "query should reference orders table" + ); + assert!( + order_items_join.query.contains("users"), + "query should reference users table" + ); + assert!( + !order_items_join.query.contains("order_items"), + "query should NOT reference order_items table" + ); + } + + /// Test that CopyParser correctly sets up FK lookup when parsing COPY for a table + /// that doesn't have the sharding key directly but has an FK path to it. 
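+ /// + /// In sketch form, the per-row routing this exercises (mirroring the shard() text branch above): read the FK value from the row, run the precomputed lookup query against the parent table to fetch the sharding key, and hash it into a Shard::Direct(n); a NULL FK value or a failed lookup falls back to Shard::All.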
+ #[tokio::test] + async fn test_copy_parser_fk_lookup_initialization() { + use crate::backend::pool::test::pool; + use crate::backend::pool::Request; + use crate::backend::schema::Schema; + use crate::backend::Cluster; + use crate::config::{DataType, Hasher, ShardedTable}; + + let pool = pool(); + let mut conn = pool.get(&Request::default()).await.unwrap(); + + // Create FK tables in the database + conn.execute("DROP TABLE IF EXISTS pgdog.fk_copy_orders, pgdog.fk_copy_users") + .await + .unwrap(); + + conn.execute( + "CREATE TABLE pgdog.fk_copy_users ( + id BIGINT PRIMARY KEY, + customer_id BIGINT NOT NULL + )", + ) + .await + .unwrap(); + + conn.execute( + "CREATE TABLE pgdog.fk_copy_orders ( + id BIGINT PRIMARY KEY, + user_id BIGINT REFERENCES pgdog.fk_copy_users(id) + )", + ) + .await + .unwrap(); + + // Load schema from database + let mut db_schema = Schema::load(&mut conn).await.unwrap(); + + // Create sharded tables config with customer_id as sharding key for fk_copy_users + let sharded_tables = crate::backend::ShardedTables::new( + vec![ShardedTable { + database: "pgdog".into(), + name: Some("fk_copy_users".into()), + schema: Some("pgdog".into()), + column: "customer_id".into(), + primary: true, + data_type: DataType::Bigint, + hasher: Hasher::Postgres, + ..Default::default() + }], + vec![], + false, // omnisharded_sticky + pgdog_config::SystemCatalogsBehavior::default(), + ); + + // Create sharding schema from the config + let sharding_schema = crate::backend::ShardingSchema { + shards: 2, + tables: sharded_tables.clone(), + schemas: crate::backend::replication::ShardedSchemas::new(vec![]), + ..Default::default() + }; + + // Compute joins so tables are marked as sharded + db_schema.computed_sharded_joins(&sharding_schema); + + // Verify the schema was loaded correctly with FK relationships + let orders_table = Table { + name: "fk_copy_orders", + schema: Some("pgdog"), + alias: None, + }; + let orders = db_schema.table(orders_table, "", None); + assert!(orders.is_some(), "fk_copy_orders should be in schema"); + assert!( + orders.unwrap().is_sharded, + "fk_copy_orders should be marked as sharded" + ); + + let orders_join = orders.and_then(|r| db_schema.get_sharded_join(r)); + assert!( + orders_join.is_some(), + "fk_copy_orders should have a join path" + ); + + // Create a test cluster with custom sharding config and schema + let cluster = Cluster::new_test_with_sharding(sharded_tables, db_schema); + cluster.launch(); + + // Parse COPY statement for orders table (has FK, no direct sharding key) + let copy = "COPY pgdog.fk_copy_orders (id, user_id) FROM STDIN"; + let stmt = parse(copy).unwrap(); + let stmt = stmt.protobuf.stmts.first().unwrap(); + let copy_stmt = match stmt.stmt.clone().unwrap().node.unwrap() { + NodeEnum::CopyStmt(copy) => copy, + _ => panic!("not a copy"), + }; + + let mut parser = CopyParser::new(&copy_stmt, &cluster).unwrap(); + + // Verify FK lookup is set up + assert!( + parser.fk_lookup.is_some(), + "fk_lookup should be set for table with FK path to sharding key" + ); + assert_eq!( + parser.fk_column, + Some(1), + "fk_column should be position 1 (user_id is second column)" + ); + assert!( + parser.is_sharded, + "fk_copy_orders table should be marked as sharded" + ); + assert!( + parser.sharded_table.is_none(), + "fk_copy_orders table doesn't have sharding key directly" + ); + + // Insert users with known customer_id values for FK lookup + for i in 1i64..=10 { + conn.execute(&format!( + "INSERT INTO pgdog.fk_copy_users (id, customer_id) VALUES ({}, {})", + i, + i *
100 // customer_id determines shard + )) + .await + .unwrap(); + } + + // Create COPY text format data - orders referencing users + // Format: order_id\tuser_id + use crate::net::messages::CopyData; + + let copy_rows: String = (1i64..=10) + .map(|i| format!("{}\t{}\n", i * 10, i)) // order_id, user_id (FK) + .collect(); + + let copy_data = CopyData::new(copy_rows.as_bytes()); + let rows = parser.shard(&[copy_data]).await.unwrap(); + + // Verify each row was routed to a specific shard (not Shard::All) + assert_eq!(rows.len(), 10, "should have 10 rows"); + for (i, row) in rows.iter().enumerate() { + assert!( + matches!(row.shard(), super::Shard::Direct(_)), + "row {} should be routed to a specific shard, got {:?}", + i, + row.shard() + ); + } + + // Clean up + conn.execute("DROP TABLE IF EXISTS pgdog.fk_copy_orders, pgdog.fk_copy_users") + .await + .unwrap(); + cluster.shutdown(); + } + + /// Test that CopyParser does NOT set up FK lookup when table has sharding key directly. + #[tokio::test] + async fn test_copy_parser_direct_sharding_key() { + use crate::backend::pool::test::pool; + use crate::backend::pool::Request; + use crate::backend::schema::Schema; + use crate::backend::Cluster; + use crate::config::{DataType, Hasher, ShardedTable}; + + let pool = pool(); + let mut conn = pool.get(&Request::default()).await.unwrap(); + + // Create FK tables in the database + conn.execute("DROP TABLE IF EXISTS pgdog.fk_copy_orders, pgdog.fk_copy_users") + .await + .unwrap(); + + conn.execute( + "CREATE TABLE pgdog.fk_copy_users ( + id BIGINT PRIMARY KEY, + customer_id BIGINT NOT NULL + )", + ) + .await + .unwrap(); + + // Load schema from database + let mut db_schema = Schema::load(&mut conn).await.unwrap(); + + // Create sharded tables config with customer_id as sharding key for fk_copy_users + let sharded_tables = crate::backend::ShardedTables::new( + vec![ShardedTable { + database: "pgdog".into(), + name: Some("fk_copy_users".into()), + schema: Some("pgdog".into()), + column: "customer_id".into(), + primary: true, + data_type: DataType::Bigint, + hasher: Hasher::Postgres, + ..Default::default() + }], + vec![], + false, // omnisharded_sticky + pgdog_config::SystemCatalogsBehavior::default(), + ); + + // Create sharding schema from the config + let sharding_schema = crate::backend::ShardingSchema { + shards: 2, + tables: sharded_tables.clone(), + schemas: crate::backend::replication::ShardedSchemas::new(vec![]), + ..Default::default() + }; + + // Compute joins so tables are marked as sharded + db_schema.computed_sharded_joins(&sharding_schema); + + // Create a test cluster with custom sharding config and schema + let cluster = Cluster::new_test_with_sharding(sharded_tables, db_schema); + + // Parse COPY statement for users table (has sharding key directly) + let copy = "COPY pgdog.fk_copy_users (id, customer_id) FROM STDIN"; + let stmt = parse(copy).unwrap(); + let stmt = stmt.protobuf.stmts.first().unwrap(); + let copy_stmt = match stmt.stmt.clone().unwrap().node.unwrap() { + NodeEnum::CopyStmt(copy) => copy, + _ => panic!("not a copy"), + }; + + let mut parser = CopyParser::new(&copy_stmt, &cluster).unwrap(); + + // Verify direct sharding is used, not FK lookup + assert!( + parser.fk_lookup.is_none(), + "fk_lookup should NOT be set for table with direct sharding key" + ); + assert!( + parser.is_sharded, + "fk_copy_users table should be marked as sharded" + ); + assert!( + parser.sharded_table.is_some(), + "fk_copy_users table has sharding key directly" + ); + assert_eq!(parser.sharded_column, 1,
"customer_id is at position 1"); + + // Create COPY text format data - users with customer_id (sharding key) + // Format: idcustomer_id + use crate::net::messages::CopyData; + + let copy_rows: String = (1i64..=10) + .map(|i| format!("{}\t{}\n", i, i * 100)) // id, customer_id (sharding key) + .collect(); + + let copy_data = CopyData::new(copy_rows.as_bytes()); + let rows = parser.shard(&[copy_data]).await.unwrap(); + + // Verify each row was routed to a specific shard (not Shard::All) + assert_eq!(rows.len(), 10, "should have 10 rows"); + for (i, row) in rows.iter().enumerate() { + assert!( + matches!(row.shard(), super::Shard::Direct(_)), + "row {} should be routed to a specific shard, got {:?}", + i, + row.shard() + ); + } + + // Clean up + conn.execute("DROP TABLE IF EXISTS pgdog.fk_copy_orders, pgdog.fk_copy_users") + .await + .unwrap(); + } } diff --git a/pgdog/src/frontend/router/parser/error.rs b/pgdog/src/frontend/router/parser/error.rs index bddf059fc..24569266e 100644 --- a/pgdog/src/frontend/router/parser/error.rs +++ b/pgdog/src/frontend/router/parser/error.rs @@ -16,6 +16,9 @@ pub enum Error { #[error("no sharding column in CSV")] NoShardingColumn, + #[error("no primary key column in row")] + NoPrimaryKeyColumn, + #[error("{0}")] Net(#[from] crate::net::Error), diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index 449dee099..7945f433a 100644 --- a/pgdog/src/frontend/router/parser/mod.rs +++ b/pgdog/src/frontend/router/parser/mod.rs @@ -57,7 +57,7 @@ pub use rewrite::{ pub use route::{Route, Shard, ShardWithPriority, ShardsWithPriority}; pub use schema::Schema; pub use sequence::{OwnedSequence, Sequence}; -pub use statement::{SchemaLookupContext, StatementParser}; +pub use statement::{CopySharding, SchemaLookupContext, StatementParser}; pub use table::{OwnedTable, Table}; pub use tuple::Tuple; pub use value::Value; diff --git a/pgdog/src/frontend/router/parser/statement.rs b/pgdog/src/frontend/router/parser/statement.rs index 7e962ddf8..8276926a8 100644 --- a/pgdog/src/frontend/router/parser/statement.rs +++ b/pgdog/src/frontend/router/parser/statement.rs @@ -2,7 +2,8 @@ use std::collections::{HashMap, HashSet}; use pg_query::{ protobuf::{ - AExprKind, BoolExprType, DeleteStmt, InsertStmt, RangeVar, RawStmt, SelectStmt, UpdateStmt, + AExprKind, BoolExprType, CopyStmt, DeleteStmt, InsertStmt, RangeVar, RawStmt, SelectStmt, + UpdateStmt, }, Node, NodeEnum, }; @@ -165,6 +166,14 @@ enum Statement<'a> { Update(&'a UpdateStmt), Delete(&'a DeleteStmt), Insert(&'a InsertStmt), + Copy(&'a CopyStmt), +} + +#[derive(Debug, Clone)] +pub struct CopySharding<'a> { + pub key_position: Option, + pub key_table: Option<&'a ShardedTable>, + pub primary_key_position: Option, } /// Context for looking up table columns from the database schema. @@ -281,6 +290,11 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Self::new(Statement::Insert(stmt), bind, schema, recorder) } + /// COPY only uses simple protocol (no bind) and doesn't support EXPLAIN. + pub fn from_copy(stmt: &'a CopyStmt, schema: &'b ShardingSchema) -> Self { + Self::new(Statement::Copy(stmt), None, schema, None) + } + /// Record a sharding key match. 
fn record_sharding_key(&mut self, shard: &Shard, column: Column<'_>, value: &Value<'_>) { self.hooks @@ -313,6 +327,7 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Some(NodeEnum::UpdateStmt(stmt)) => Ok(Self::from_update(stmt, bind, schema, recorder)), Some(NodeEnum::DeleteStmt(stmt)) => Ok(Self::from_delete(stmt, bind, schema, recorder)), Some(NodeEnum::InsertStmt(stmt)) => Ok(Self::from_insert(stmt, bind, schema, recorder)), + Some(NodeEnum::CopyStmt(stmt)) => Ok(Self::from_copy(stmt, schema)), _ => Err(Error::NotASelect), } } @@ -329,6 +344,8 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Statement::Update(stmt) => self.shard_update(stmt), Statement::Delete(stmt) => self.shard_delete(stmt), Statement::Insert(stmt) => self.shard_insert(stmt), + // COPY doesn't have values in the SQL statement - routing happens per-row + Statement::Copy(_) => Ok(None), }?; // Key-based sharding succeeded @@ -360,9 +377,9 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Ok(None) } - /// Check that the query references a table that contains a sharded - /// column. This check is needed in case sharded tables config - /// doesn't specify a table name and should short-circuit if it does. + /// Check that the query references a table that can participate in sharded routing. + /// Uses the precomputed `is_sharded` flag on relations which accounts for both + /// direct sharding keys and FK join paths to sharded tables. pub fn is_sharded( &mut self, db_schema: &Schema, @@ -374,42 +391,114 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { return false; } - let sharded_tables = self.schema.tables.tables(); + for table in self.tables() { + if let Some(relation) = db_schema.table(*table, user, search_path) { + if relation.is_sharded { + return true; + } + } + } - // Separate configs with explicit table names from those without - let (named, nameless): (Vec<_>, Vec<_>) = - sharded_tables.iter().partition(|t| t.name.is_some()); + false + } - for table in self.tables() { - // Check named sharded table configs (fast path, no schema lookup needed) - for config in &named { - if let Some(ref name) = config.name { - if table.name == name { - // Also check schema match if specified in config - if let Some(ref config_schema) = config.schema { - if table.schema != Some(config_schema.as_str()) { - continue; - } - } - return true; - } + /// Get column names for a COPY statement with primary key info. + /// Returns Vec<(is_primary_key, column_name)>. 
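+ /// For example, COPY orders (id, user_id) with id as the table's primary key yields [(true, "id"), (false, "user_id")]; without a schema lookup, every flag is false.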
+ fn get_copy_columns(&self, stmt: &CopyStmt, table: &Table<'_>) -> Vec<(bool, String)> { + // First try to get columns from the COPY statement itself + let col_names: Vec<String> = stmt + .attlist + .iter() + .filter_map(|s| Column::from_string(s).ok()) + .map(|c| c.name.to_string()) + .collect(); + + if !col_names.is_empty() { + // Look up primary key info from schema if available + if let Some(ref schema_lookup) = self.schema_lookup { + if let Some(relation) = schema_lookup.db_schema.table( + *table, + schema_lookup.user, + schema_lookup.search_path, + ) { + return col_names + .into_iter() + .map(|name| { + let is_pk = relation + .columns() + .get(&name) + .map(|c| c.is_primary_key) + .unwrap_or(false); + (is_pk, name) + }) + .collect(); } } + // No schema lookup, return columns without PK info + return col_names.into_iter().map(|name| (false, name)).collect(); + } - - // Check nameless configs by looking up the table in the db schema - // to see if it has the sharding column - if !nameless.is_empty() { - if let Some(relation) = db_schema.table(*table, user, search_path) { - for config in &nameless { - if relation.has_column(&config.column) { - return true; - } - } + // No columns specified in COPY, try to look them up from schema + if let Some(ref schema_lookup) = self.schema_lookup { + if let Some(relation) = + schema_lookup + .db_schema + .table(*table, schema_lookup.user, schema_lookup.search_path) + { + return relation + .columns() + .iter() + .map(|(name, col)| (col.is_primary_key, name.clone())) + .collect(); + } + } + + vec![] + } + + pub fn copy_sharding_key(&self) -> Option<CopySharding<'b>> { + let stmt = match self.stmt { + Statement::Copy(stmt) => stmt, + _ => return None, + }; + + let relation = stmt.relation.as_ref()?; + let table = Table::from(relation); + + // Get column names with PK info from COPY statement or schema lookup + let columns = self.get_copy_columns(stmt, &table); + + // Find the primary key position + let primary_key_position = columns.iter().position(|(is_pk, _)| *is_pk); + + let sharded_tables = self.schema.tables.tables(); + + // Check tables with explicit name first + for sharded in sharded_tables.iter().filter(|t| t.name.is_some()) { + if sharded.name.as_deref() == Some(table.name) { + if let Some(position) = columns.iter().position(|(_, name)| name == &sharded.column) + { + return Some(CopySharding { + key_position: Some(position), + key_table: Some(sharded), + primary_key_position, + }); + } + } + } - false + // Check tables without name (column-only config) + for sharded in sharded_tables.iter().filter(|t| t.name.is_none()) { + if let Some(position) = columns.iter().position(|(_, name)| name == &sharded.column) { + return Some(CopySharding { + key_position: Some(position), + key_table: Some(sharded), + primary_key_position, + }); + } + } + + None } /// Extract all tables referenced in the statement.
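[Editorial aside, not part of the patch: how a caller might consume the key_position returned by copy_sharding_key. Once the key's column index is known, each incoming COPY row can be routed independently of the SQL statement. The sketch below is a minimal illustration; route_copy_row and the modulo hash are hypothetical stand-ins, not pgdog's actual row parser or sharding function.]

// Hypothetical illustration of consuming CopySharding::key_position.
fn route_copy_row(row: &str, key_position: usize, shards: usize) -> Option<usize> {
    // Text-format COPY separates columns with tabs and ends rows with a newline.
    let value = row.trim_end_matches('\n').split('\t').nth(key_position)?;
    // Stand-in for the real sharding function: parse as BIGINT, take modulo of shard count.
    let key: i64 = value.parse().ok()?;
    Some(key.rem_euclid(shards as i64) as usize)
}

#[test]
fn routes_by_key_position() {
    // Mirrors "COPY users (name, id, email)": the sharding key "id" sits at position 1.
    assert_eq!(route_copy_row("alice\t42\ta@x.com\n", 1, 3), Some(0)); // 42 % 3 == 0
}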
@@ -420,6 +509,7 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Statement::Update(stmt) => self.extract_tables_from_update(stmt, &mut tables), Statement::Delete(stmt) => self.extract_tables_from_delete(stmt, &mut tables), Statement::Insert(stmt) => self.extract_tables_from_insert(stmt, &mut tables), + Statement::Copy(stmt) => self.extract_tables_from_copy(stmt, &mut tables), } tables } @@ -544,6 +634,12 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { } } + fn extract_tables_from_copy(&self, stmt: &'a CopyStmt, tables: &mut Vec<Table<'a>>) { + if let Some(ref relation) = stmt.relation { + tables.push(Table::from(relation)); + } + } + fn extract_tables_from_node(&self, node: &'a Node, tables: &mut Vec<Table<'a>>) { match &node.node { Some(NodeEnum::RangeVar(range_var)) => { @@ -2565,7 +2661,9 @@ mod test { fn run_is_sharded_test(stmt: &str) -> bool { let schema = make_omnisharded_sharding_schema(); - let db_schema = make_omnisharded_db_schema(); + let mut db_schema = make_omnisharded_db_schema(); + // Compute joins to set is_sharded flags on relations + db_schema.computed_sharded_joins(&schema); let raw = pg_query::parse(stmt) .unwrap() .protobuf .stmts .first() .cloned() .unwrap(); @@ -2644,4 +2742,311 @@ mod test { "DELETE from omnisharded table should not be sharded" ); } + + /// Test helper for FK join sharding detection. + /// Creates a schema with users (sharding key) and orders (FK to users). + fn run_fk_join_sharding_test(stmt: &str) -> bool { + use crate::backend::schema::test_helpers::prelude::*; + + let mut db_schema = schema() + .relation( + table("users") + .oid(1001) + .column(pk("id")) + .column(col("user_id")), + ) + .relation( + table("orders") + .oid(1002) + .column(pk("id")) + .column(fk("user_id", "users", "id")), + ) + .relation(table("isolated").oid(1003).column(pk("id"))) + .build(); + + let sharding_schema = sharding().sharded_table("users", "user_id").build(); + db_schema.computed_sharded_joins(&sharding_schema); + + let raw = pg_query::parse(stmt) + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let mut parser = StatementParser::from_raw(&raw, None, &sharding_schema, None).unwrap(); + parser.is_sharded(&db_schema, "test_user", None) + } + + #[test] + fn test_is_sharded_detects_fk_join_to_sharded_table() { + // orders doesn't have sharding key, but FK joins to users which does + let result = run_fk_join_sharding_test("SELECT * FROM orders WHERE id = 1"); + assert!(result, "orders should be sharded via FK join to users"); + } + + #[test] + fn test_is_sharded_direct_table_with_sharding_key() { + // users has the sharding key directly + let result = run_fk_join_sharding_test("SELECT * FROM users WHERE id = 1"); + assert!(result, "users has sharding key directly"); + } + + #[test] + fn test_is_sharded_table_without_fk_path() { + // isolated has no FK path to any sharded table + let result = run_fk_join_sharding_test("SELECT * FROM isolated WHERE id = 1"); + assert!(!result, "isolated has no FK path to sharded table"); + } + + #[test] + fn test_is_sharded_join_query_with_fk_table() { + // Query joining orders (FK) with users (sharded) + let result = + run_fk_join_sharding_test("SELECT * FROM orders o JOIN users u ON o.user_id = u.id"); + assert!(result, "join query with sharded table should be sharded"); + } + + #[test] + fn test_is_sharded_insert_on_fk_table() { + let result = run_fk_join_sharding_test("INSERT INTO orders (id, user_id) VALUES (1, 100)"); + assert!(result, "insert on FK table should be sharded"); + } + + #[test] + fn test_is_sharded_update_on_fk_table() { + let result =
run_fk_join_sharding_test("UPDATE orders SET user_id = 100 WHERE id = 1"); + assert!(result, "update on FK table should be sharded"); + } + + #[test] + fn test_is_sharded_delete_on_fk_table() { + let result = run_fk_join_sharding_test("DELETE FROM orders WHERE id = 1"); + assert!(result, "delete on FK table should be sharded"); + } + + // COPY sharding key tests + + #[test] + fn test_copy_sharding_key_found() { + let sharding_schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "id".into(), + name: Some("users".into()), + ..Default::default() + }], + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + }; + + let raw = pg_query::parse("COPY users (name, id, email) FROM STDIN") + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let parser = StatementParser::from_raw(&raw, None, &sharding_schema, None).unwrap(); + let result = parser.copy_sharding_key(); + + assert!(result.is_some(), "should find sharding key"); + let sharding = result.unwrap(); + assert_eq!( + sharding.key_position, + Some(1), + "id is at position 1 (0-indexed)" + ); + assert_eq!(sharding.key_table.unwrap().column, "id"); + } + + #[test] + fn test_copy_sharding_key_first_position() { + let sharding_schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "tenant_id".into(), + name: Some("orders".into()), + ..Default::default() + }], + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + }; + + let raw = pg_query::parse("COPY orders (tenant_id, amount, status) FROM STDIN CSV") + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let parser = StatementParser::from_raw(&raw, None, &sharding_schema, None).unwrap(); + let result = parser.copy_sharding_key(); + + assert!(result.is_some()); + let sharding = result.unwrap(); + assert_eq!(sharding.key_position, Some(0), "tenant_id is at position 0"); + } + + #[test] + fn test_copy_sharding_key_not_found() { + let sharding_schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "user_id".into(), + name: Some("users".into()), + ..Default::default() + }], + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + }; + + // COPY on a different table + let raw = pg_query::parse("COPY orders (id, amount) FROM STDIN") + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let parser = StatementParser::from_raw(&raw, None, &sharding_schema, None).unwrap(); + let result = parser.copy_sharding_key(); + + assert!( + result.is_none(), + "should not find sharding key for different table" + ); + } + + #[test] + fn test_copy_sharding_key_column_only_config() { + // Column-only sharding config (no table name specified) + let sharding_schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "tenant_id".into(), + // No table name - matches any table with this column + ..Default::default() + }], + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + }; + + let raw = pg_query::parse("COPY any_table (id, tenant_id, data) FROM STDIN") + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let parser = StatementParser::from_raw(&raw, None, &sharding_schema, None).unwrap(); + let result = parser.copy_sharding_key(); + + assert!(result.is_some(), "column-only config should match"); + let sharding = result.unwrap(); + 
assert_eq!(sharding.key_position, Some(1), "tenant_id is at position 1"); + } + + #[test] + fn test_copy_sharding_key_non_copy_returns_none() { + let sharding_schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "id".into(), + name: Some("users".into()), + ..Default::default() + }], + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + }; + + let raw = pg_query::parse("SELECT * FROM users WHERE id = 1") + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let parser = StatementParser::from_raw(&raw, None, &sharding_schema, None).unwrap(); + let result = parser.copy_sharding_key(); + + assert!(result.is_none(), "non-COPY statement should return None"); + } + + #[test] + fn test_copy_sharding_key_without_column_list() { + // COPY without explicit column list should find sharding key from schema + let sharding_schema = ShardingSchema { + shards: 3, + tables: ShardedTables::new( + vec![ShardedTable { + column: "id".into(), + name: Some("sharded".into()), + ..Default::default() + }], + vec![], + false, + SystemCatalogsBehavior::default(), + ), + ..Default::default() + }; + + let db_schema = make_test_schema_with_relation(); + let schema_lookup = SchemaLookupContext { + db_schema: &db_schema, + user: "test", + search_path: None, + }; + + let raw = pg_query::parse("COPY sharded FROM STDIN") + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + + let parser = StatementParser::from_raw(&raw, None, &sharding_schema, None) + .unwrap() + .with_schema_lookup(schema_lookup); + let result = parser.copy_sharding_key(); + + assert!( + result.is_some(), + "should find sharding key from schema when COPY has no column list" + ); + let sharding = result.unwrap(); + assert_eq!( + sharding.key_position, + Some(0), + "id is at position 0 (first column in schema)" + ); + assert_eq!(sharding.key_table.unwrap().column, "id"); + } } diff --git a/pgdog/src/frontend/router/sharding/tables.rs b/pgdog/src/frontend/router/sharding/tables.rs index aaa7120c3..179206ec7 100644 --- a/pgdog/src/frontend/router/sharding/tables.rs +++ b/pgdog/src/frontend/router/sharding/tables.rs @@ -1,14 +1,4 @@ -use crate::{ - backend::ShardingSchema, - config::ShardedTable, - frontend::router::parser::{Column, Table}, -}; - -#[derive(Debug)] -pub struct Key<'a> { - pub table: &'a ShardedTable, - pub position: usize, -} +use crate::{backend::ShardingSchema, config::ShardedTable, frontend::router::parser::Table}; pub struct Tables<'a> { schema: &'a ShardingSchema, @@ -20,49 +10,11 @@ impl<'a> Tables<'a> { } pub(crate) fn sharded(&'a self, table: Table) -> Option<&'a ShardedTable> { - let tables = self.schema.tables().tables(); - - let sharded = tables + self.schema + .tables() + .tables() .iter() .filter(|table| table.name.is_some()) - .find(|t| t.name.as_deref() == Some(table.name)); - - sharded - } - - pub(crate) fn key(&'a self, table: Table, columns: &'a [Column]) -> Option<Key<'a>> { - let tables = self.schema.tables().tables(); - - // Check tables with name first. - let sharded = tables - .iter() - .filter(|table| table.name.is_some()) - .find(|t| t.name.as_deref() == Some(table.name)); - - if let Some(sharded) = sharded { - if let Some(position) = columns.iter().position(|col| col.name == sharded.column) { - return Some(Key { - table: sharded, - position, - }); - } - } - - // Check tables without name.
- let key: Option<(&'a ShardedTable, Option<usize>)> = tables - .iter() - .filter(|table| table.name.is_none()) - .map(|t| (t, columns.iter().position(|col| col.name == t.column))) - .find(|t| t.1.is_some()); - if let Some(key) = key { - if let Some(position) = key.1 { - return Some(Key { - table: key.0, - position, - }); - } - } - - None + .find(|t| t.name.as_deref() == Some(table.name)) } } diff --git a/pgdog/src/frontend/router/sharding/value.rs b/pgdog/src/frontend/router/sharding/value.rs index 7a77a628a..1e6c10a83 100644 --- a/pgdog/src/frontend/router/sharding/value.rs +++ b/pgdog/src/frontend/router/sharding/value.rs @@ -5,7 +5,7 @@ use uuid::Uuid; use super::{Error, Hasher}; use crate::{ config::DataType, - net::{Format, FromDataType, ParameterWithFormat, Vector}, + net::{bind::Parameter, Format, FromDataType, ParameterWithFormat, Vector}, }; use bytes::Bytes; @@ -16,6 +16,18 @@ pub enum Data<'a> { Integer(i64), } +impl Data<'_> { + /// Get parameter with format encoding. + pub(crate) fn parameter_with_format(&self) -> (Parameter, Format) { + let format = match self { + Self::Text(_) => Format::Text, + _ => Format::Binary, + }; + + (self.clone().into(), format) + } +} + impl<'a> From<&'a str> for Data<'a> { fn from(value: &'a str) -> Self { Self::Text(value) } } @@ -40,6 +52,25 @@ impl<'a> From<&'a Bytes> for Data<'a> { } } +impl From<Data<'_>> for Parameter { + fn from(value: Data<'_>) -> Self { + match value { + Data::Text(text) => Parameter { + len: text.len() as i32, + data: text.as_bytes().to_vec().into(), + }, + Data::Binary(binary) => Parameter { + len: binary.len() as i32, + data: binary.to_vec().into(), + }, + Data::Integer(integer) => Parameter { + len: 8, + data: integer.to_be_bytes().to_vec().into(), + }, + } + } +} + #[derive(Debug, Clone)] pub struct Value<'a> { data_type: DataType, diff --git a/pgdog/src/net/messages/mod.rs b/pgdog/src/net/messages/mod.rs index 9604a22c3..bae36f64e 100644 --- a/pgdog/src/net/messages/mod.rs +++ b/pgdog/src/net/messages/mod.rs @@ -135,41 +135,53 @@ impl MemoryUsage for Message { impl std::fmt::Debug for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + macro_rules!
try_fmt { + ($expr:expr) => { + match $expr { + Ok(m) => m.fmt(f), + Err(_) => f + .debug_struct("Message") + .field("code", &self.code()) + .field("len", &self.payload().len()) + .finish(), + } + }; + } + match self.code() { - 'Q' => Query::from_bytes(self.payload()).unwrap().fmt(f), + 'Q' => try_fmt!(Query::from_bytes(self.payload())), 'D' => match self.source { - Source::Backend(_) => DataRow::from_bytes(self.payload()).unwrap().fmt(f), - Source::Frontend => Describe::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) => try_fmt!(DataRow::from_bytes(self.payload())), + Source::Frontend => try_fmt!(Describe::from_bytes(self.payload())), }, - 'P' => Parse::from_bytes(self.payload()).unwrap().fmt(f), - 'B' => Bind::from_bytes(self.payload()).unwrap().fmt(f), + 'P' => try_fmt!(Parse::from_bytes(self.payload())), + 'B' => try_fmt!(Bind::from_bytes(self.payload())), 'S' => match self.source { Source::Frontend => f.debug_struct("Sync").finish(), - Source::Backend(_) => ParameterStatus::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) => try_fmt!(ParameterStatus::from_bytes(self.payload())), }, - '1' => ParseComplete::from_bytes(self.payload()).unwrap().fmt(f), - '2' => BindComplete::from_bytes(self.payload()).unwrap().fmt(f), + '1' => try_fmt!(ParseComplete::from_bytes(self.payload())), + '2' => try_fmt!(BindComplete::from_bytes(self.payload())), '3' => f.debug_struct("CloseComplete").finish(), 'E' => match self.source { Source::Frontend => f.debug_struct("Execute").finish(), - Source::Backend(_) => ErrorResponse::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) => try_fmt!(ErrorResponse::from_bytes(self.payload())), }, - 'T' => RowDescription::from_bytes(self.payload()).unwrap().fmt(f), - 'Z' => ReadyForQuery::from_bytes(self.payload()).unwrap().fmt(f), + 'T' => try_fmt!(RowDescription::from_bytes(self.payload())), + 'Z' => try_fmt!(ReadyForQuery::from_bytes(self.payload())), 'C' => match self.source { - Source::Backend(_) => CommandComplete::from_bytes(self.payload()).unwrap().fmt(f), - Source::Frontend => Close::from_bytes(self.payload()).unwrap().fmt(f), + Source::Backend(_) => try_fmt!(CommandComplete::from_bytes(self.payload())), + Source::Frontend => try_fmt!(Close::from_bytes(self.payload())), }, - 'd' => CopyData::from_bytes(self.payload()).unwrap().fmt(f), + 'd' => try_fmt!(CopyData::from_bytes(self.payload())), 'W' => f.debug_struct("CopyBothResponse").finish(), 'I' => f.debug_struct("EmptyQueryResponse").finish(), - 't' => ParameterDescription::from_bytes(self.payload()) - .unwrap() - .fmt(f), + 't' => try_fmt!(ParameterDescription::from_bytes(self.payload())), 'H' => f.debug_struct("Flush").finish(), _ => f .debug_struct("Message") - .field("payload", &self.payload()) + .field("code", &self.code()) + .field("len", &self.payload().len()) .finish(), } }