From 58960b1bd26ef9968c286036e3f3aa7cbb972df9 Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Sat, 21 Mar 2026 14:32:28 +0100 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20Add=20tokio=20feature=20for=20s?= =?UTF-8?q?td=20environment=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The macro now supports two mutually exclusive features: `embassy` (default, no_std) and `tokio` (std). Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/rust.yaml | 12 +- CLAUDE.md | 67 ++- Cargo.lock | 54 ++ Cargo.toml | 15 +- README.md | 43 +- src/controller/item_impl.rs | 955 +++++++++++++++++++++++++++------- src/controller/item_struct.rs | 238 +++++++-- src/controller/mod.rs | 2 + src/lib.rs | 6 + tests/integration.rs | 389 ++++++++------ 10 files changed, 1344 insertions(+), 437 deletions(-) diff --git a/.github/workflows/rust.yaml b/.github/workflows/rust.yaml index f0aae8f..0a6805d 100644 --- a/.github/workflows/rust.yaml +++ b/.github/workflows/rust.yaml @@ -12,13 +12,16 @@ env: jobs: test: runs-on: ubuntu-latest + strategy: + matrix: + backend: ["embassy", "tokio"] steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: toolchain: stable - name: Test - run: cargo --locked test + run: cargo --locked test --no-default-features --features "${{ matrix.backend }}" fmt: runs-on: ubuntu-latest @@ -33,6 +36,9 @@ jobs: clippy: runs-on: ubuntu-latest + strategy: + matrix: + backend: ["embassy", "tokio"] steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master @@ -40,4 +46,6 @@ jobs: toolchain: stable components: clippy - name: Catch common mistakes - run: cargo --locked clippy -- -D warnings + run: >- + cargo --locked clippy --no-default-features + --features "${{ matrix.backend }}" -- -D warnings diff --git a/CLAUDE.md b/CLAUDE.md index 6ada361..73608c6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,10 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Project Overview -This is a procedural macro crate that provides the `#[controller]` attribute macro for firmware development in `no_std` environments. The macro generates boilerplate for decoupling component interactions through: +This is a procedural macro crate that provides the `#[controller]` attribute macro for +firmware/actor development. By default it targets `no_std` environments using embassy. With the +`tokio` feature, it generates code for `std` environments using tokio. The macro generates +boilerplate for decoupling component interactions through: * A controller struct that manages peripheral state. * Client API for sending commands to the controller. @@ -16,9 +19,12 @@ The macro is applied to a module containing both the controller struct definitio ## Build & Test Commands ```bash -# Run all tests (includes doc tests from README) +# Run all tests with default (embassy) backend cargo test --locked +# Run all tests with tokio backend +cargo test --locked --no-default-features --features tokio + # Run a specific test cargo test --locked @@ -28,8 +34,9 @@ cargo +nightly fmt -- --check # Auto-format code (requires nightly) cargo +nightly fmt -# Run clippy (CI fails on warnings) +# Run clippy for both backends (CI fails on warnings) cargo clippy --locked -- -D warnings +cargo clippy --locked --no-default-features --features tokio -- -D warnings # Build the crate cargo build --locked @@ -40,8 +47,25 @@ cargo doc --locked ## Architecture +### Backend Selection + +The crate has two mutually exclusive features: `embassy` (default) and `tokio`. Code generation +functions use `#[cfg(feature = "...")]` in the proc macro code (not in generated code) to select +which token streams to emit. When `tokio` is enabled: + +* `embassy_sync::channel::Channel` → `tokio::sync::mpsc` + `tokio::sync::oneshot` + (request/response actor pattern) +* `embassy_sync::watch::Watch` → `tokio::sync::watch` (via `std::sync::OnceLock`) +* `embassy_sync::pubsub::PubSubChannel` → `tokio::sync::broadcast` + (via `std::sync::LazyLock`, with `tokio_stream::wrappers::BroadcastStream`) +* Watch subscribers use `tokio_stream::wrappers::WatchStream`. +* `embassy_time::Ticker` → `tokio::time::interval` +* `futures::select_biased!` → `tokio::select! { biased; ... }` +* Static channels use `std::sync::LazyLock` since tokio channels lack const constructors. + ### Macro Entry Point (`src/lib.rs`) -The `controller` attribute macro parses the input as an `ItemMod` (module) and calls `controller::expand_module()`. +The `controller` attribute macro parses the input as an `ItemMod` (module) and calls +`controller::expand_module()`. ### Module Processing (`src/controller/mod.rs`) The `expand_module()` function: @@ -53,15 +77,16 @@ The `expand_module()` function: Channel capacities and subscriber limits are also defined here: * `ALL_CHANNEL_CAPACITY`: 8 (method/getter/setter request channels) -* `SIGNAL_CHANNEL_CAPACITY`: 8 (signal PubSubChannel queue size) -* `BROADCAST_MAX_PUBLISHERS`: 1 (signals only) -* `BROADCAST_MAX_SUBSCRIBERS`: 16 (Watch for published fields, PubSubChannel for signals) +* `SIGNAL_CHANNEL_CAPACITY`: 8 (signal PubSubChannel/broadcast queue size) +* `BROADCAST_MAX_PUBLISHERS`: 1 (signals only, embassy only) +* `BROADCAST_MAX_SUBSCRIBERS`: 16 (Watch for published fields, PubSubChannel for signals, + embassy only) ### Struct Processing (`src/controller/item_struct.rs`) Processes the controller struct definition. Supports three field attributes: **`#[controller(publish)]`** - Enables state change subscriptions: -* Uses `embassy_sync::watch::Watch` channel (stores latest value). +* Uses `embassy_sync::watch::Watch` (or `tokio::sync::watch`) channel (stores latest value). * Generates internal setter (`set_`) that broadcasts changes. * Creates `` subscriber stream type. * Stream yields current value on first poll, then subsequent changes. @@ -82,13 +107,14 @@ initial values to Watch channels so subscribers get them immediately. Processes the controller impl block. Distinguishes between: **Proxied methods** (normal methods): -* Creates request/response channels for each method. +* Creates request/response channels for each method. With tokio, uses `mpsc` + `oneshot` for the + request/response actor pattern. * Generates matching client-side methods that send requests and await responses. * Adds arms to the controller's `run()` method select loop to handle requests. **Signal methods** (marked with `#[controller(signal)]`): * Methods have no body in the user's impl block. -* Uses `embassy_sync::pubsub::PubSubChannel` for broadcast. +* Uses `embassy_sync::pubsub::PubSubChannel` (or `tokio::sync::broadcast`) for broadcast. * Generates method implementation that broadcasts to subscribers. * Creates `` stream type and `Args` struct. * Signal methods are NOT exposed in the client API (controller emits them directly). @@ -102,15 +128,16 @@ Processes the controller impl block. Distinguishes between: * Methods with the same timeout value (same unit and value) are grouped into a single ticker arm. * All methods in a group are called sequentially when the ticker fires (in declaration order). * Poll methods are NOT exposed in the client API (internal to the controller). -* Uses `embassy_time::Ticker::every()` for timing. +* Uses `embassy_time::Ticker::every()` (or `tokio::time::interval()`) for timing. **Getter/setter methods** (from struct field attributes): * Receives getter/setter field info from struct processing. * Generates client-side getter methods that request current field value. * Generates client-side setter methods that update field value (and broadcast if published). -The generated `run()` method contains a `select_biased!` loop that receives method calls from -clients, dispatches them to the user's implementations, and handles periodic poll method calls. +The generated `run()` method contains a `select_biased!` (or `tokio::select! { biased; ... }`) loop +that receives method calls from clients, dispatches them to the user's implementations, and handles +periodic poll method calls. ### Utilities (`src/util.rs`) Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used for generating type and method names. @@ -118,18 +145,28 @@ Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used ## Dependencies User code must have these dependencies (per README): + +**Default (embassy)**: * `futures` with `async-await` feature. * `embassy-sync` for channels and synchronization. * `embassy-time` for poll method timing (only required if using poll methods). -Dev dependencies include `embassy-executor` and `embassy-time` for testing. +**With `tokio` feature**: +* `futures` with `async-await` feature. +* `tokio` with `sync` feature (and `time` if using poll methods). +* `tokio-stream` with `sync` feature. + +Dev dependencies include `embassy-executor`, `embassy-time`, `tokio`, and `tokio-stream` for +testing. ## Key Limitations * Singleton operation: multiple controller instances interfere with each other. * Methods must be async and cannot use reference parameters/return types. * Maximum 16 subscribers per state/signal stream. -* Published fields must implement `Clone`. +* Published fields must implement `Clone`. With `tokio`, they must also implement `Send + Sync`. +* Signal argument types must implement `Clone`. With `tokio`, they must also implement + `Send + 'static`. * Published field streams yield current value on first poll; intermediate values may be missed if not polled between changes. * Signal streams must be continuously polled or notifications are missed. diff --git a/Cargo.lock b/Cargo.lock index 5c80bb8..04b713a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cfg-if" version = "1.0.3" @@ -212,6 +218,8 @@ dependencies = [ "proc-macro2", "quote", "syn", + "tokio", + "tokio-stream", ] [[package]] @@ -482,6 +490,52 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "unicode-ident" version = "1.0.19" diff --git a/Cargo.toml b/Cargo.toml index 1635944..709d1de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "firmware-controller" -description = "Controller to decouple interactions between components in a no_std environment." +description = "Controller (actor) macro to decouple interactions between components, supporting both embassy (no_std) and tokio (std) backends." version = "0.4.2" edition = "2021" authors = [ @@ -13,6 +13,11 @@ repository = "https://github.com/layerx-world/firmware-controller/" [lib] proc-macro = true +[features] +default = ["embassy"] +embassy = [] +tokio = [] + [dependencies] proc-macro2 = "1" quote = "1" @@ -32,3 +37,11 @@ embassy-executor = { version = "0.9.1", features = [ "executor-thread", ] } embassy-time = { version = "0.5.0", features = ["mock-driver"] } +tokio = { version = "1", features = [ + "macros", + "rt", + "sync", + "test-util", + "time", +] } +tokio-stream = { version = "0.1", features = ["sync"] } diff --git a/README.md b/README.md index 2b546da..e8abec5 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # Firmware Controller This crate provides a macro named `controller` that makes it easy to decouple interactions between -components in a `no_std` environment. +components. It works in both `no_std` (embassy) and `std` (tokio) environments. [Intro](#intro) • [Usage](#usage) • @@ -13,8 +13,9 @@ components in a `no_std` environment. # Intro -This crate provides a macro named `controller` that makes it easy to write controller logic for -firmware. +This crate provides a macro named `controller` that makes it easy to write controller (actor) +logic. By default it targets `no_std` firmware using embassy, but with the `tokio` feature it +generates code for `std` environments using tokio instead. The controller is responsible for control of all the peripherals based on commands it receives from other parts of the code. It also notifies peers about state changes and events via signals. @@ -241,17 +242,49 @@ Key characteristics: entirely by the controller's `run()` loop. * Methods with the same timeout value (same unit and value) are grouped into a single timer arm and called sequentially when the timer fires (in declaration order). -* Uses `embassy_time::Ticker` for timing, which ensures consistent intervals regardless of method - execution time. +* Uses `embassy_time::Ticker` (or `tokio::time::interval` with the `tokio` feature) for timing, + which ensures consistent intervals regardless of method execution time. + +## Backend selection + +The crate provides two mutually exclusive features: `embassy` (default) and `tokio`. Exactly one +must be enabled; enabling both or neither is a compile error. + +By default, the macro generates code targeting embassy for `no_std` firmware. To use tokio instead, +disable the default feature and enable `tokio`: + +```toml +[dependencies] +firmware-controller = { version = "0.4", default-features = false, features = ["tokio"] } +``` + +When the `tokio` feature is enabled: +* Channels use `tokio::sync::mpsc` with `oneshot` for request/response. +* Published fields use `tokio::sync::watch` (via `tokio_stream::wrappers::WatchStream`). +* Signals use `tokio::sync::broadcast` (via `tokio_stream::wrappers::BroadcastStream`). +* Poll methods use `tokio::time::interval`. +* The `run()` loop uses `tokio::select!`. + +**Note:** With the `tokio` feature, signal argument types must additionally implement `Send + +'static` (required by `tokio::sync::broadcast`), and published field types must implement `Send + +Sync` (required by `tokio::sync::watch`). These constraints do not apply to the `embassy` backend. ## Dependencies assumed The `controller` macro assumes that you have the following dependencies in your `Cargo.toml`: +### Default (embassy) + * `futures` with `async-await` feature enabled. * `embassy-sync` * `embassy-time` (only required if using poll methods) +### With `tokio` feature + +* `futures` with `async-await` feature enabled. +* `tokio` with `sync` feature (and `time` if using poll methods) +* `tokio-stream` with `sync` feature + ## Known limitations & Caveats * Currently only works as a singleton: you can create multiple instances of the controller but diff --git a/src/controller/item_impl.rs b/src/controller/item_impl.rs index 279e188..ac7f82b 100644 --- a/src/controller/item_impl.rs +++ b/src/controller/item_impl.rs @@ -74,6 +74,17 @@ pub(crate) fn expand( let pub_getter_client_tx_rx_initializations = pub_getters.iter().map(|g| &g.client_tx_rx_initializations); + // Collect select arms for the select body generation. + let select_arms_vec: Vec<_> = select_arms.clone().collect(); + let setter_arms_vec: Vec<_> = pub_setter_select_arms.clone().collect(); + let getter_arms_vec: Vec<_> = pub_getter_select_arms.clone().collect(); + let select_body = generate_select_body( + &select_arms_vec, + &setter_arms_vec, + &getter_arms_vec, + &poll_select_arms, + ); + let run_method = quote! { pub async fn run(mut self) { #(#args_channels_rx_tx)* @@ -82,12 +93,7 @@ pub(crate) fn expand( #(#poll_ticker_declarations)* loop { - futures::select_biased! { - #(#select_arms,)* - #(#pub_setter_select_arms,)* - #(#pub_getter_select_arms,)* - #(#poll_select_arms,)* - } + #select_body } } }; @@ -163,6 +169,41 @@ pub(crate) fn expand( }) } +#[cfg(feature = "embassy")] +fn generate_select_body( + arms: &[&TokenStream], + setter_arms: &[&TokenStream], + getter_arms: &[&TokenStream], + poll_arms: &[TokenStream], +) -> TokenStream { + quote! { + futures::select_biased! { + #(#arms,)* + #(#setter_arms,)* + #(#getter_arms,)* + #(#poll_arms,)* + } + } +} + +#[cfg(feature = "tokio")] +fn generate_select_body( + arms: &[&TokenStream], + setter_arms: &[&TokenStream], + getter_arms: &[&TokenStream], + poll_arms: &[TokenStream], +) -> TokenStream { + quote! { + tokio::select! { + biased; + #(#arms)* + #(#setter_arms)* + #(#getter_arms)* + #(#poll_arms)* + } + } +} + fn get_methods(input: &mut ItemImpl, struct_name: &Ident) -> Result> { input .items @@ -257,23 +298,8 @@ struct ProxiedMethod { impl ProxiedMethod { fn parse(method: &ImplItemFn, struct_name: &Ident) -> Result { - let method_args = ProxiedMethodArgs::parse(method)?; - - let (args_channel_declarations, input_channel_name, output_channel_name) = - method_args.generate_args_channel_declarations(struct_name); - let (args_channels_rx_tx, select_arm) = - method_args.generate_args_channel_rx_tx(&input_channel_name, &output_channel_name); - let (client_method, client_method_tx_rx_declarations, client_method_tx_rx_initializations) = - method_args.generate_client_method(&input_channel_name, &output_channel_name); - - Ok(Self { - args_channel_declarations, - args_channels_rx_tx, - select_arm, - client_method, - client_method_tx_rx_declarations, - client_method_tx_rx_initializations, - }) + let args = ProxiedMethodArgs::parse(method)?; + Ok(args.generate(struct_name)) } } @@ -284,8 +310,8 @@ struct ProxiedMethodArgs<'a> { out_type: TokenStream, } -impl ProxiedMethodArgs<'_> { - fn parse(method: &ImplItemFn) -> Result> { +impl<'a> ProxiedMethodArgs<'a> { + fn parse(method: &'a ImplItemFn) -> Result { let in_args = MethodInputArgs::parse(method)?; let out_type = match &method.sig.output { syn::ReturnType::Type(_, ty) => quote! { #ty }, @@ -298,18 +324,21 @@ impl ProxiedMethodArgs<'_> { out_type, }) } +} - fn generate_args_channel_declarations( - &self, - struct_name: &Ident, - ) -> (TokenStream, Ident, Ident) { - let in_types = &self.in_args.types; - let out_type = &self.out_type; +#[cfg(feature = "embassy")] +impl ProxiedMethodArgs<'_> { + fn generate(&self, struct_name: &Ident) -> ProxiedMethod { let method_name = &self.method.sig.ident; let method_name_str = method_name.to_string(); + let in_types = &self.in_args.types; + let in_names = &self.in_args.names; + let out_type = &self.out_type; let struct_name_caps = struct_name.to_string().to_uppercase(); let method_name_caps = method_name_str.to_uppercase(); + let capacity = super::ALL_CHANNEL_CAPACITY; + let input_channel_name = Ident::new( &format!("{struct_name_caps}_{method_name_caps}_INPUT_CHANNEL"), self.method.span(), @@ -318,7 +347,7 @@ impl ProxiedMethodArgs<'_> { &format!("{struct_name_caps}_{method_name_caps}_OUTPUT_CHANNEL"), self.method.span(), ); - let capacity = super::ALL_CHANNEL_CAPACITY; + let args_channel_declarations = quote! { static #input_channel_name: embassy_sync::channel::Channel< @@ -334,99 +363,194 @@ impl ProxiedMethodArgs<'_> { > = embassy_sync::channel::Channel::new(); }; - ( - args_channel_declarations, - input_channel_name, - output_channel_name, - ) - } - - // Also generates the select! arm for the method dispatch. - fn generate_args_channel_rx_tx( - &self, - input_channel_name: &Ident, - output_channel_name: &Ident, - ) -> (TokenStream, TokenStream) { - let in_names = &self.in_args.names; - let method_name = &self.method.sig.ident; - let method_name_str = method_name.to_string(); let input_channel_rx_name = Ident::new(&format!("{method_name_str}_rx"), self.method.span()); let output_channel_tx_name = Ident::new(&format!("{method_name_str}_tx"), self.method.span()); let args_channels_rx_tx = quote! { - let #input_channel_rx_name = embassy_sync::channel::Channel::receiver(&#input_channel_name); - let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); + let #input_channel_rx_name = + embassy_sync::channel::Channel::receiver( + &#input_channel_name, + ); + let #output_channel_tx_name = + embassy_sync::channel::Channel::sender( + &#output_channel_name, + ); }; + let select_arm = quote! { (#(#in_names),*) = futures::FutureExt::fuse( - embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + embassy_sync::channel::Receiver::receive( + &#input_channel_rx_name, + ), ) => { - let ret = self.#method_name(#(#in_names),*).await; + let ret = + self.#method_name(#(#in_names),*).await; - embassy_sync::channel::Sender::send(&#output_channel_tx_name, ret).await; + embassy_sync::channel::Sender::send( + &#output_channel_tx_name, + ret, + ).await; } }; - (args_channels_rx_tx, select_arm) - } - - // Also generate the input and output channel declarations, and initializations. - fn generate_client_method( - &self, - input_channel_name: &Ident, - output_channel_name: &Ident, - ) -> (TokenStream, TokenStream, TokenStream) { - let method_name = &self.method.sig.ident; - let in_names = &self.in_args.names; - let in_names = if in_names.is_empty() { + let in_names_tuple = if in_names.is_empty() { quote! { () } } else { quote! { (#(#in_names),*) } }; - let method_name_str = method_name.to_string(); let input_channel_tx_name = Ident::new(&format!("{method_name_str}_tx"), self.method.span()); let output_channel_rx_name = Ident::new(&format!("{method_name_str}_rx"), self.method.span()); - let mut method = self.method.clone(); - - method.block = parse_quote!({ + let mut client_method = self.method.clone(); + client_method.block = parse_quote!({ // Method call. - embassy_sync::channel::Sender::send(&self.#input_channel_tx_name, #in_names).await; + embassy_sync::channel::Sender::send( + &self.#input_channel_tx_name, + #in_names_tuple, + ).await; // Method return. - embassy_sync::channel::Receiver::receive(&self.#output_channel_rx_name).await + embassy_sync::channel::Receiver::receive( + &self.#output_channel_rx_name, + ).await }); + let client_method_tx_rx_declarations = quote! { + #input_channel_tx_name: + embassy_sync::channel::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (#(#in_types),*), + #capacity, + >, + #output_channel_rx_name: + embassy_sync::channel::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #out_type, + #capacity, + >, + }; + + let client_method_tx_rx_initializations = quote! { + #input_channel_tx_name: + embassy_sync::channel::Channel::sender( + &#input_channel_name, + ), + #output_channel_rx_name: + embassy_sync::channel::Channel::receiver( + &#output_channel_name, + ), + }; + + ProxiedMethod { + args_channel_declarations, + args_channels_rx_tx, + select_arm, + client_method: quote! { #client_method }, + client_method_tx_rx_declarations, + client_method_tx_rx_initializations, + } + } +} + +#[cfg(feature = "tokio")] +impl ProxiedMethodArgs<'_> { + fn generate(&self, struct_name: &Ident) -> ProxiedMethod { + let method_name = &self.method.sig.ident; + let method_name_str = method_name.to_string(); let in_types = &self.in_args.types; + let in_names = &self.in_args.names; let out_type = &self.out_type; + + let struct_name_caps = struct_name.to_string().to_uppercase(); + let method_name_caps = method_name_str.to_uppercase(); let capacity = super::ALL_CHANNEL_CAPACITY; - let tx_rx_declarations = quote! { - #input_channel_tx_name: embassy_sync::channel::Sender< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - (#(#in_types),*), - #capacity, - >, - #output_channel_rx_name: embassy_sync::channel::Receiver< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - #out_type, - #capacity, - >, + + let channel_name = Ident::new( + &format!("{struct_name_caps}_{method_name_caps}_CHANNEL"), + self.method.span(), + ); + + let args_channel_declarations = quote! { + static #channel_name: std::sync::LazyLock<( + tokio::sync::mpsc::Sender<( + (#(#in_types,)*), + tokio::sync::oneshot::Sender<#out_type>, + )>, + std::sync::Mutex< + Option< + tokio::sync::mpsc::Receiver<( + (#(#in_types,)*), + tokio::sync::oneshot::Sender<#out_type>, + )>, + >, + >, + )> = std::sync::LazyLock::new(|| { + let (tx, rx) = + tokio::sync::mpsc::channel(#capacity); + (tx, std::sync::Mutex::new(Some(rx))) + }); + }; + + let rx_name = Ident::new(&format!("{method_name_str}_rx"), self.method.span()); + let args_channels_rx_tx = quote! { + let mut #rx_name = #channel_name + .1 + .lock() + .unwrap() + .take() + .unwrap(); }; - let tx_rx_initializations = quote! { - #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), - #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name), + let select_arm = quote! { + Some(((#(#in_names,)*), __resp_tx)) = + #rx_name.recv() => + { + let ret = + self.#method_name(#(#in_names),*).await; + __resp_tx.send(ret).ok(); + } }; - ( - quote! { #method }, - tx_rx_declarations, - tx_rx_initializations, - ) + let tx_name = Ident::new(&format!("{method_name_str}_tx"), self.method.span()); + let args = if in_names.is_empty() { + quote! { () } + } else { + quote! { (#(#in_names,)*) } + }; + let mut client_method = self.method.clone(); + client_method.block = parse_quote!({ + let (__resp_tx, __resp_rx) = + tokio::sync::oneshot::channel(); + self.#tx_name + .send((#args, __resp_tx)) + .await + .ok(); + __resp_rx.await.unwrap() + }); + + let client_method_tx_rx_declarations = quote! { + #tx_name: tokio::sync::mpsc::Sender<( + (#(#in_types,)*), + tokio::sync::oneshot::Sender<#out_type>, + )>, + }; + + let client_method_tx_rx_initializations = quote! { + #tx_name: #channel_name.0.clone(), + }; + + ProxiedMethod { + args_channel_declarations, + args_channels_rx_tx, + select_arm, + client_method: quote! { #client_method }, + client_method_tx_rx_declarations, + client_method_tx_rx_initializations, + } } } @@ -443,8 +567,19 @@ struct Signal { impl Signal { fn parse(method: &mut ImplItemFn, struct_name: &Ident) -> Result { remove_signal_attr(method)?; + let args = MethodInputArgs::parse(method)?; + Self::generate(method, struct_name, &args) + } +} - let MethodInputArgs { types, names } = MethodInputArgs::parse(method)?; +#[cfg(feature = "embassy")] +impl Signal { + fn generate( + method: &mut ImplItemFn, + struct_name: &Ident, + args: &MethodInputArgs, + ) -> Result { + let MethodInputArgs { types, names } = args; let method_name = &method.sig.ident; let method_name_str = method_name.to_string(); @@ -455,10 +590,6 @@ impl Signal { &format!("{struct_name_caps}_{method_name_caps}_CHANNEL"), method.span(), ); - let signal_publisher_name = Ident::new( - &format!("{struct_name_caps}_{method_name_caps}_PUBLISHER"), - method.span(), - ); let subscriber_struct_name = Ident::new(&format!("{struct_name}{method_name_pascal}"), method.span()); let args_struct_name = Ident::new( @@ -469,6 +600,10 @@ impl Signal { let capacity = super::SIGNAL_CHANNEL_CAPACITY; let max_subscribers = super::BROADCAST_MAX_SUBSCRIBERS; let max_publishers = super::BROADCAST_MAX_PUBLISHERS; + let signal_publisher_name = Ident::new( + &format!("{struct_name_caps}_{method_name_caps}_PUBLISHER"), + method.span(), + ); let declarations = quote! { static #signal_channel_name: @@ -480,15 +615,17 @@ impl Signal { #max_publishers, > = embassy_sync::pubsub::PubSubChannel::new(); - static #signal_publisher_name: embassy_sync::once_lock::OnceLock> = embassy_sync::once_lock::OnceLock::new(); + static #signal_publisher_name: + embassy_sync::once_lock::OnceLock< + embassy_sync::pubsub::publisher::Publisher< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #args_struct_name, + #capacity, + #max_subscribers, + #max_publishers, + >, + > = embassy_sync::once_lock::OnceLock::new(); #[derive(Debug, Clone)] pub struct #args_struct_name { @@ -496,21 +633,24 @@ impl Signal { } pub struct #subscriber_struct_name { - subscriber: embassy_sync::pubsub::subscriber::Subscriber< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - #args_struct_name, - #capacity, - #max_subscribers, - #max_publishers, - >, + subscriber: + embassy_sync::pubsub::subscriber::Subscriber< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #args_struct_name, + #capacity, + #max_subscribers, + #max_publishers, + >, } impl #subscriber_struct_name { pub fn new() -> Option { - embassy_sync::pubsub::PubSubChannel::subscriber(&#signal_channel_name) - .ok() - .map(|subscriber| Self { subscriber }) + embassy_sync::pubsub::PubSubChannel::subscriber( + &#signal_channel_name, + ) + .ok() + .map(|subscriber| Self { subscriber }) } } @@ -521,21 +661,131 @@ impl Signal { self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>, ) -> core::task::Poll> { - let subscriber = core::pin::Pin::new(&mut *self.get_mut().subscriber); + let subscriber = core::pin::Pin::new( + &mut *self.get_mut().subscriber, + ); futures::Stream::poll_next(subscriber, cx) } } }; method.block = parse_quote!({ - let publisher = embassy_sync::once_lock::OnceLock::get_or_init( - &#signal_publisher_name, - // Safety: The publisher is only initialized once. - || embassy_sync::pubsub::PubSubChannel::publisher(&#signal_channel_name).unwrap()); + let publisher = + embassy_sync::once_lock::OnceLock::get_or_init( + &#signal_publisher_name, + // Safety: The publisher is only initialized once. + || embassy_sync::pubsub::PubSubChannel::publisher( + &#signal_channel_name, + ) + .unwrap(), + ); embassy_sync::pubsub::publisher::Pub::publish( publisher, #args_struct_name { #(#names),* }, - ).await; + ) + .await; + }); + + let receive_method_name = + Ident::new(&format!("receive_{}", method_name_str), method.span()); + + Ok(Self { + declarations, + receive_method_name, + subscriber_struct_name, + }) + } +} + +#[cfg(feature = "tokio")] +impl Signal { + fn generate( + method: &mut ImplItemFn, + struct_name: &Ident, + args: &MethodInputArgs, + ) -> Result { + let MethodInputArgs { types, names } = args; + + let method_name = &method.sig.ident; + let method_name_str = method_name.to_string(); + let struct_name_caps = struct_name.to_string().to_uppercase(); + let method_name_caps = method_name_str.to_uppercase(); + let method_name_pascal = snake_to_pascal_case(&method_name_str); + let signal_channel_name = Ident::new( + &format!("{struct_name_caps}_{method_name_caps}_CHANNEL"), + method.span(), + ); + let subscriber_struct_name = + Ident::new(&format!("{struct_name}{method_name_pascal}"), method.span()); + let args_struct_name = Ident::new( + &format!("{struct_name}{method_name_pascal}Args"), + method.span(), + ); + + let capacity = super::SIGNAL_CHANNEL_CAPACITY; + + let declarations = quote! { + static #signal_channel_name: + std::sync::LazyLock< + tokio::sync::broadcast::Sender< + #args_struct_name, + >, + > = std::sync::LazyLock::new(|| { + let (tx, _) = + tokio::sync::broadcast::channel( + #capacity, + ); + tx + }); + + #[derive(Debug, Clone)] + pub struct #args_struct_name { + #(pub #names: #types),* + } + + pub struct #subscriber_struct_name { + inner: core::pin::Pin< + Box< + dyn futures::Stream< + Item = #args_struct_name, + > + Send, + >, + >, + } + + impl #subscriber_struct_name { + pub fn new() -> Option { + use futures::StreamExt; + + Some(Self { + inner: Box::pin( + tokio_stream::wrappers::BroadcastStream::new( + #signal_channel_name.subscribe(), + ) + .filter_map(|result| async move { + result.ok() + }), + ), + }) + } + } + + impl futures::Stream for #subscriber_struct_name { + type Item = #args_struct_name; + + fn poll_next( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + self.get_mut().inner.as_mut().poll_next(cx) + } + } + }; + + method.block = parse_quote!({ + #signal_channel_name + .send(#args_struct_name { #(#names),* }) + .ok(); }); let receive_method_name = @@ -549,7 +799,8 @@ impl Signal { } } -/// Duration for poll methods. Used as a key for grouping methods with the same timeout. +/// Duration for poll methods. Used as a key for grouping methods with the +/// same timeout. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] enum PollDuration { Seconds(u64), @@ -557,12 +808,36 @@ enum PollDuration { Micros(u64), } +#[cfg(feature = "embassy")] impl PollDuration { fn to_duration_expr(&self) -> TokenStream { match self { - PollDuration::Seconds(n) => quote! { embassy_time::Duration::from_secs(#n) }, - PollDuration::Millis(n) => quote! { embassy_time::Duration::from_millis(#n) }, - PollDuration::Micros(n) => quote! { embassy_time::Duration::from_micros(#n) }, + PollDuration::Seconds(n) => { + quote! { embassy_time::Duration::from_secs(#n) } + } + PollDuration::Millis(n) => { + quote! { embassy_time::Duration::from_millis(#n) } + } + PollDuration::Micros(n) => { + quote! { embassy_time::Duration::from_micros(#n) } + } + } + } +} + +#[cfg(feature = "tokio")] +impl PollDuration { + fn to_duration_expr(&self) -> TokenStream { + match self { + PollDuration::Seconds(n) => { + quote! { std::time::Duration::from_secs(#n) } + } + PollDuration::Millis(n) => { + quote! { std::time::Duration::from_millis(#n) } + } + PollDuration::Micros(n) => { + quote! { std::time::Duration::from_micros(#n) } + } } } } @@ -592,7 +867,8 @@ impl PollMethod { if value == 0 { return Err(syn::Error::new_spanned( lit, - "poll duration must be greater than zero", + "poll duration must be greater \ + than zero", )); } Some((PollDuration::Seconds(value), meta.path.span())) @@ -603,7 +879,8 @@ impl PollMethod { if value == 0 { return Err(syn::Error::new_spanned( lit, - "poll duration must be greater than zero", + "poll duration must be greater \ + than zero", )); } Some((PollDuration::Millis(value), meta.path.span())) @@ -614,7 +891,8 @@ impl PollMethod { if value == 0 { return Err(syn::Error::new_spanned( lit, - "poll duration must be greater than zero", + "poll duration must be greater \ + than zero", )); } Some((PollDuration::Micros(value), meta.path.span())) @@ -626,7 +904,8 @@ impl PollMethod { if duration.is_some() { return Err(syn::Error::new( span, - "only one poll attribute is allowed per method", + "only one poll attribute is allowed \ + per method", )); } duration = Some((new_dur, span)); @@ -649,7 +928,8 @@ impl PollMethod { if has_non_receiver_params { return Err(syn::Error::new_spanned( &method.sig.inputs, - "poll methods cannot have parameters (only `&self` or `&mut self` is allowed)", + "poll methods cannot have parameters (only \ + `&self` or `&mut self` is allowed)", )); } @@ -690,8 +970,10 @@ fn remove_poll_attr(method: &mut ImplItemFn) -> syn::Result<()> { .map(|ident| ident.to_string()) .unwrap_or_else(|| quote!(#path).to_string()); let e = format!( - "poll methods cannot have other `controller` attributes (found `{}`); \ - remove attributes like `getter`, `setter`, `publish`, or `signal`", + "poll methods cannot have other \ + `controller` attributes (found `{}`); \ + remove attributes like `getter`, \ + `setter`, `publish`, or `signal`", found ); Err(syn::Error::new_spanned(meta.path, e)) @@ -707,7 +989,8 @@ fn remove_poll_attr(method: &mut ImplItemFn) -> syn::Result<()> { Ok(()) } -// Like ImplItemFn, but with a semicolon at the end instead of a body block +// Like ImplItemFn, but with a semicolon at the end instead of a body +// block. struct ImplItemSignal { attrs: Vec, vis: Visibility, @@ -777,7 +1060,8 @@ impl MethodInputArgs { _ => { return Some(Err(syn::Error::new( arg.span(), - "Expected identifier as argument name", + "Expected identifier as argument \ + name", ))) } }; @@ -818,6 +1102,7 @@ struct PubGetter { client_tx_rx_initializations: TokenStream, } +#[cfg(feature = "embassy")] fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSetter { let field_name = &field.field_name; let field_type = &field.field_type; @@ -826,6 +1111,8 @@ fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSette let struct_name_caps = struct_name.to_string().to_uppercase(); let field_name_caps = field_name_str.to_uppercase(); + let capacity = super::ALL_CHANNEL_CAPACITY; + let input_channel_name = Ident::new( &format!("{}_SET_{}_INPUT_CHANNEL", struct_name_caps, field_name_caps), field_name.span(), @@ -837,7 +1124,6 @@ fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSette ), field_name.span(), ); - let capacity = super::ALL_CHANNEL_CAPACITY; let channel_declarations = quote! { static #input_channel_name: @@ -859,31 +1145,34 @@ fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSette let output_channel_tx_name = Ident::new(&format!("{}_ack_tx", field_name_str), field_name.span()); let rx_tx = quote! { - let #input_channel_rx_name = embassy_sync::channel::Channel::receiver(&#input_channel_name); - let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); + let #input_channel_rx_name = + embassy_sync::channel::Channel::receiver( + &#input_channel_name, + ); + let #output_channel_tx_name = + embassy_sync::channel::Channel::sender( + &#output_channel_name, + ); }; - let select_arm = if let Some(internal_setter) = &field.internal_setter_name { - // Published field: call the internal setter which broadcasts changes. - quote! { - value = futures::FutureExt::fuse( - embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), - ) => { - self.#internal_setter(value).await; - - embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; - } - } + let setter_body = if let Some(internal_setter) = &field.internal_setter_name { + quote! { self.#internal_setter(value).await; } } else { - // Non-published field: set the field directly. - quote! { - value = futures::FutureExt::fuse( - embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), - ) => { - self.#field_name = value; + quote! { self.#field_name = value; } + }; - embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; - } + let select_arm = quote! { + value = futures::FutureExt::fuse( + embassy_sync::channel::Receiver::receive( + &#input_channel_rx_name, + ), + ) => { + #setter_body + + embassy_sync::channel::Sender::send( + &#output_channel_tx_name, + (), + ).await; } }; @@ -892,30 +1181,145 @@ fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSette let output_channel_rx_name = Ident::new(&format!("{}_ack_rx", field_name_str), field_name.span()); let client_method = quote! { - pub async fn #setter_method_name(&self, value: #field_type) { - embassy_sync::channel::Sender::send(&self.#input_channel_tx_name, value).await; - embassy_sync::channel::Receiver::receive(&self.#output_channel_rx_name).await + pub async fn #setter_method_name( + &self, + value: #field_type, + ) { + embassy_sync::channel::Sender::send( + &self.#input_channel_tx_name, + value, + ).await; + embassy_sync::channel::Receiver::receive( + &self.#output_channel_rx_name, + ).await } }; let client_tx_rx_declarations = quote! { - #input_channel_tx_name: embassy_sync::channel::Sender< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #input_channel_tx_name: + embassy_sync::channel::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + >, + #output_channel_rx_name: + embassy_sync::channel::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + >, + }; + + let client_tx_rx_initializations = quote! { + #input_channel_tx_name: + embassy_sync::channel::Channel::sender( + &#input_channel_name, + ), + #output_channel_rx_name: + embassy_sync::channel::Channel::receiver( + &#output_channel_name, + ), + }; + + PubSetter { + channel_declarations, + rx_tx, + select_arm, + client_method, + client_tx_rx_declarations, + client_tx_rx_initializations, + } +} + +#[cfg(feature = "tokio")] +fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSetter { + let field_name = &field.field_name; + let field_type = &field.field_type; + let setter_method_name = &field.setter_name; + let field_name_str = field_name.to_string(); + + let struct_name_caps = struct_name.to_string().to_uppercase(); + let field_name_caps = field_name_str.to_uppercase(); + let capacity = super::ALL_CHANNEL_CAPACITY; + + let channel_name = Ident::new( + &format!("{}_SET_{}_CHANNEL", struct_name_caps, field_name_caps), + field_name.span(), + ); + + let channel_declarations = quote! { + static #channel_name: std::sync::LazyLock<( + tokio::sync::mpsc::Sender<( + #field_type, + tokio::sync::oneshot::Sender<()>, + )>, + std::sync::Mutex< + Option< + tokio::sync::mpsc::Receiver<( + #field_type, + tokio::sync::oneshot::Sender<()>, + )>, + >, + >, + )> = std::sync::LazyLock::new(|| { + let (tx, rx) = + tokio::sync::mpsc::channel(#capacity); + (tx, std::sync::Mutex::new(Some(rx))) + }); + }; + + let rx_name = Ident::new(&format!("{}_set_rx", field_name_str), field_name.span()); + let rx_tx = quote! { + let mut #rx_name = #channel_name + .1 + .lock() + .unwrap() + .take() + .unwrap(); + }; + + let setter_body = if let Some(internal_setter) = &field.internal_setter_name { + quote! { self.#internal_setter(value).await; } + } else { + quote! { self.#field_name = value; } + }; + + let select_arm = quote! { + Some((value, __resp_tx)) = + #rx_name.recv() => + { + #setter_body + __resp_tx.send(()).ok(); + } + }; + + let tx_name = Ident::new(&format!("{}_set_tx", field_name_str), field_name.span()); + let client_method = quote! { + pub async fn #setter_method_name( + &self, + value: #field_type, + ) { + let (__resp_tx, __resp_rx) = + tokio::sync::oneshot::channel(); + self.#tx_name + .send((value, __resp_tx)) + .await + .ok(); + __resp_rx.await.unwrap() + } + }; + + let client_tx_rx_declarations = quote! { + #tx_name: tokio::sync::mpsc::Sender<( #field_type, - #capacity, - >, - #output_channel_rx_name: embassy_sync::channel::Receiver< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - (), - #capacity, - >, + tokio::sync::oneshot::Sender<()>, + )>, }; let client_tx_rx_initializations = quote! { - #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), - #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name), + #tx_name: #channel_name.0.clone(), }; PubSetter { @@ -928,6 +1332,7 @@ fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSette } } +#[cfg(feature = "embassy")] fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGetter { let field_name = &field.field_name; let field_type = &field.field_type; @@ -936,6 +1341,8 @@ fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGette let struct_name_caps = struct_name.to_string().to_uppercase(); let field_name_caps = field_name_str.to_uppercase(); + let capacity = super::ALL_CHANNEL_CAPACITY; + let input_channel_name = Ident::new( &format!("{}_GET_{}_INPUT_CHANNEL", struct_name_caps, field_name_caps), field_name.span(), @@ -947,7 +1354,6 @@ fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGette ), field_name.span(), ); - let capacity = super::ALL_CHANNEL_CAPACITY; let channel_declarations = quote! { static #input_channel_name: @@ -973,17 +1379,29 @@ fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGette field_name.span(), ); let rx_tx = quote! { - let #input_channel_rx_name = embassy_sync::channel::Channel::receiver(&#input_channel_name); - let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); + let #input_channel_rx_name = + embassy_sync::channel::Channel::receiver( + &#input_channel_name, + ); + let #output_channel_tx_name = + embassy_sync::channel::Channel::sender( + &#output_channel_name, + ); }; let select_arm = quote! { _ = futures::FutureExt::fuse( - embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + embassy_sync::channel::Receiver::receive( + &#input_channel_rx_name, + ), ) => { - let value = core::clone::Clone::clone(&self.#field_name); + let value = + core::clone::Clone::clone(&self.#field_name); - embassy_sync::channel::Sender::send(&#output_channel_tx_name, value).await; + embassy_sync::channel::Sender::send( + &#output_channel_tx_name, + value, + ).await; } }; @@ -997,29 +1415,125 @@ fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGette ); let client_method = quote! { pub async fn #getter_name(&self) -> #field_type { - embassy_sync::channel::Sender::send(&self.#input_channel_tx_name, ()).await; - embassy_sync::channel::Receiver::receive(&self.#output_channel_rx_name).await + embassy_sync::channel::Sender::send( + &self.#input_channel_tx_name, + (), + ).await; + embassy_sync::channel::Receiver::receive( + &self.#output_channel_rx_name, + ).await } }; let client_tx_rx_declarations = quote! { - #input_channel_tx_name: embassy_sync::channel::Sender< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - (), - #capacity, - >, - #output_channel_rx_name: embassy_sync::channel::Receiver< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - #field_type, - #capacity, + #input_channel_tx_name: + embassy_sync::channel::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + >, + #output_channel_rx_name: + embassy_sync::channel::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + >, + }; + + let client_tx_rx_initializations = quote! { + #input_channel_tx_name: + embassy_sync::channel::Channel::sender( + &#input_channel_name, + ), + #output_channel_rx_name: + embassy_sync::channel::Channel::receiver( + &#output_channel_name, + ), + }; + + PubGetter { + channel_declarations, + rx_tx, + select_arm, + client_method, + client_tx_rx_declarations, + client_tx_rx_initializations, + } +} + +#[cfg(feature = "tokio")] +fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGetter { + let field_name = &field.field_name; + let field_type = &field.field_type; + let getter_name = &field.getter_name; + let field_name_str = field_name.to_string(); + + let struct_name_caps = struct_name.to_string().to_uppercase(); + let field_name_caps = field_name_str.to_uppercase(); + let capacity = super::ALL_CHANNEL_CAPACITY; + + let channel_name = Ident::new( + &format!("{}_GET_{}_CHANNEL", struct_name_caps, field_name_caps), + field_name.span(), + ); + + let channel_declarations = quote! { + static #channel_name: std::sync::LazyLock<( + tokio::sync::mpsc::Sender< + tokio::sync::oneshot::Sender<#field_type>, + >, + std::sync::Mutex< + Option< + tokio::sync::mpsc::Receiver< + tokio::sync::oneshot::Sender<#field_type>, + >, + >, + >, + )> = std::sync::LazyLock::new(|| { + let (tx, rx) = + tokio::sync::mpsc::channel(#capacity); + (tx, std::sync::Mutex::new(Some(rx))) + }); + }; + + let rx_name = Ident::new(&format!("{}_get_rx", field_name_str), field_name.span()); + let rx_tx = quote! { + let mut #rx_name = #channel_name + .1 + .lock() + .unwrap() + .take() + .unwrap(); + }; + + let select_arm = quote! { + Some(__resp_tx) = #rx_name.recv() => { + let value = + core::clone::Clone::clone(&self.#field_name); + __resp_tx.send(value).ok(); + } + }; + + let tx_name = Ident::new(&format!("{}_get_tx", field_name_str), field_name.span()); + let client_method = quote! { + pub async fn #getter_name(&self) -> #field_type { + let (__resp_tx, __resp_rx) = + tokio::sync::oneshot::channel(); + self.#tx_name.send(__resp_tx).await.ok(); + __resp_rx.await.unwrap() + } + }; + + let client_tx_rx_declarations = quote! { + #tx_name: tokio::sync::mpsc::Sender< + tokio::sync::oneshot::Sender<#field_type>, >, }; let client_tx_rx_initializations = quote! { - #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), - #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name), + #tx_name: #channel_name.0.clone(), }; PubGetter { @@ -1032,11 +1546,51 @@ fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGette } } -/// Generate ticker declarations and select arms for poll methods, grouped by duration. +/// Generate ticker declarations and select arms for poll methods, grouped +/// by duration. /// /// Returns (ticker_declarations, select_arms) where: /// - ticker_declarations: Code to create Tickers before the loop. -/// - select_arms: Select arms that wait on ticker.next(). +/// - select_arms: Select arms that wait on ticker.next()/tick(). +#[cfg(feature = "embassy")] +fn generate_poll_code(poll_methods: &[&PollMethod]) -> (Vec, Vec) { + // Group poll methods by duration. + let mut groups: BTreeMap> = BTreeMap::new(); + for poll in poll_methods { + groups + .entry(poll.duration.clone()) + .or_default() + .push(&poll.method_name); + } + + let mut ticker_declarations = Vec::new(); + let mut select_arms = Vec::new(); + + for (index, (duration, method_names)) in groups.into_iter().enumerate() { + let duration_expr = duration.to_duration_expr(); + let ticker_name = Ident::new( + &format!("__poll_ticker_{index}"), + proc_macro2::Span::call_site(), + ); + + ticker_declarations.push(quote! { + let mut #ticker_name = + embassy_time::Ticker::every(#duration_expr); + }); + + select_arms.push(quote! { + _ = futures::FutureExt::fuse( + #ticker_name.next(), + ) => { + #(self.#method_names().await;)* + } + }); + } + + (ticker_declarations, select_arms) +} + +#[cfg(feature = "tokio")] fn generate_poll_code(poll_methods: &[&PollMethod]) -> (Vec, Vec) { // Group poll methods by duration. let mut groups: BTreeMap> = BTreeMap::new(); @@ -1058,11 +1612,14 @@ fn generate_poll_code(poll_methods: &[&PollMethod]) -> (Vec, Vec { + _ = #ticker_name.tick() => { #(self.#method_names().await;)* } }); diff --git a/src/controller/item_struct.rs b/src/controller/item_struct.rs index 9024be0..fd0ce57 100644 --- a/src/controller/item_struct.rs +++ b/src/controller/item_struct.rs @@ -50,35 +50,46 @@ pub(crate) fn expand(mut input: ItemStruct) -> Result { sender_fields_initializations, setters, subscriber_declarations, + initial_value_sends, published_fields_info, ) = struct_fields.published().fold( - (quote!(), quote!(), quote!(), quote!(), quote!(), Vec::new()), + ( + quote!(), + quote!(), + quote!(), + quote!(), + quote!(), + Vec::new(), + Vec::new(), + ), |( watch_channels, sender_fields_declarations, sender_fields_initializations, setters, subscribers, + mut init_sends, mut infos, ), f| { let published = f.published.as_ref().unwrap(); - let (watch_channel, sender_field, sender_field_init, setter, subscriber) = ( - &published.watch_channel_declaration, - &published.sender_field_declaration, - &published.sender_field_initialization, - &published.setter, - &published.subscriber_declaration, - ); infos.push(published.info.clone()); + init_sends.push(&published.initial_value_send); + + let watch_channel = &published.watch_channel_declaration; + let sender_field = &published.sender_field_declaration; + let sender_field_init = &published.sender_field_initialization; + let setter = &published.setter; + let subscriber = &published.subscriber_declaration; ( quote! { #watch_channels #watch_channel }, - quote! { #sender_fields_declarations #sender_field, }, - quote! { #sender_fields_initializations #sender_field_init, }, + quote! { #sender_fields_declarations #sender_field }, + quote! { #sender_fields_initializations #sender_field_init }, quote! { #setters #setter }, quote! { #subscribers #subscriber }, + init_sends, infos, ) }, @@ -137,14 +148,7 @@ pub(crate) fn expand(mut input: ItemStruct) -> Result { }); let vis = &input.vis; - // Generate initial value sends for Watch channels. - let initial_value_sends = published_fields_info.iter().map(|info| { - let field_name = &info.field_name; - let sender_name = Ident::new(&format!("{}_sender", field_name), field_name.span()); - quote! { - __self.#sender_name.send(core::clone::Clone::clone(&__self.#field_name)); - } - }); + // Initial value sends are already collected from PublishedFieldCode. Ok(ExpandedStruct { tokens: quote! { @@ -284,6 +288,8 @@ struct PublishedFieldCode { watch_channel_declaration: proc_macro2::TokenStream, /// Subscriber struct declaration. subscriber_declaration: proc_macro2::TokenStream, + /// Code to send initial value in `new()`. + initial_value_send: proc_macro2::TokenStream, /// Information to be passed to impl processing. info: PublishedFieldInfo, } @@ -344,27 +350,21 @@ fn parse_controller_attrs(field: &mut Field) -> Result { /// Generate code for a published field using Watch channel. fn generate_publish_code(field: &Field, struct_name: &Ident) -> Result { - let struct_name_str = struct_name.to_string(); - let field_name = field.ident.as_ref().unwrap(); - let field_name_str = field_name.to_string(); - let ty = &field.ty; - - let struct_name_caps = pascal_to_snake_case(&struct_name_str).to_ascii_uppercase(); - let field_name_caps = field_name_str.to_ascii_uppercase(); - let watch_channel_name = Ident::new( - &format!("{struct_name_caps}_{field_name_caps}_WATCH"), - field.span(), - ); - - let field_name_pascal = snake_to_pascal_case(&field_name_str); - let subscriber_struct_name = Ident::new( - &format!("{struct_name_str}{field_name_pascal}"), - field.span(), - ); - let max_subscribers = super::BROADCAST_MAX_SUBSCRIBERS; + let names = PublishNames::new(field, struct_name); + generate_publish_code_impl(&names) +} - let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); - let sender_name = Ident::new(&format!("{field_name_str}_sender"), field.span()); +#[cfg(feature = "embassy")] +fn generate_publish_code_impl(n: &PublishNames) -> Result { + let PublishNames { + field_name, + ty, + watch_channel_name, + subscriber_struct_name, + setter_name, + sender_name, + max_subscribers, + } = n; let sender_field_declaration = quote! { #sender_name: @@ -373,18 +373,20 @@ fn generate_publish_code(field: &Field, struct_name: &Ident) -> Result + >, }; let sender_field_initialization = quote! { - #sender_name: embassy_sync::watch::Watch::sender(&#watch_channel_name) + #sender_name: embassy_sync::watch::Watch::sender(&#watch_channel_name), }; // Watch send() is sync, but we keep the setter async for API compatibility. let setter = quote! { pub async fn #setter_name(&mut self, value: #ty) { self.#field_name = value; - self.#sender_name.send(core::clone::Clone::clone(&self.#field_name)); + self.#sender_name.send( + core::clone::Clone::clone(&self.#field_name), + ); } }; @@ -410,11 +412,13 @@ fn generate_publish_code(field: &Field, struct_name: &Ident) -> Result Option { - embassy_sync::watch::Watch::receiver(&#watch_channel_name) - .map(|receiver| Self { - receiver, - first_poll: true, - }) + embassy_sync::watch::Watch::receiver( + &#watch_channel_name, + ) + .map(|receiver| Self { + receiver, + first_poll: true, + }) } } @@ -445,9 +449,17 @@ fn generate_publish_code(field: &Field, struct_name: &Ident) -> Result Result Result { + let PublishNames { + field_name, + ty, + watch_channel_name, + subscriber_struct_name, + setter_name, + .. + } = n; + + let setter = quote! { + pub async fn #setter_name(&mut self, value: #ty) { + self.#field_name = value; + #watch_channel_name + .get() + .unwrap() + .send(core::clone::Clone::clone(&self.#field_name)) + .ok(); + } + }; + + let watch_channel_declaration = quote! { + static #watch_channel_name: + std::sync::OnceLock> + = std::sync::OnceLock::new(); + }; + + let subscriber_declaration = quote! { + pub struct #subscriber_struct_name { + inner: tokio_stream::wrappers::WatchStream<#ty>, + } + + impl #subscriber_struct_name { + pub fn new() -> Option { + #watch_channel_name.get().map(|sender| Self { + inner: tokio_stream::wrappers::WatchStream::new( + sender.subscribe(), + ), + }) + } + } + + impl futures::Stream for #subscriber_struct_name { + type Item = #ty; + + fn poll_next( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.get_mut(); + futures::Stream::poll_next( + core::pin::Pin::new(&mut this.inner), + cx, + ) + } + } + }; + + let initial_value_send = quote! { + let (__tx, _) = tokio::sync::watch::channel( + core::clone::Clone::clone(&__self.#field_name), + ); + #watch_channel_name.set(__tx).ok(); + }; + + let info = PublishedFieldInfo { + field_name: field_name.clone(), + subscriber_struct_name: subscriber_struct_name.clone(), + }; + + Ok(PublishedFieldCode { + sender_field_declaration: quote! {}, + sender_field_initialization: quote! {}, + setter, + watch_channel_declaration, + subscriber_declaration, + initial_value_send, info, }) } + +/// Common name generation for published fields. +struct PublishNames { + field_name: Ident, + ty: syn::Type, + watch_channel_name: Ident, + subscriber_struct_name: Ident, + setter_name: Ident, + #[cfg(feature = "embassy")] + sender_name: Ident, + #[cfg(feature = "embassy")] + max_subscribers: usize, +} + +impl PublishNames { + fn new(field: &Field, struct_name: &Ident) -> Self { + let struct_name_str = struct_name.to_string(); + let field_name = field.ident.as_ref().unwrap().clone(); + let field_name_str = field_name.to_string(); + let ty = field.ty.clone(); + + let struct_name_caps = pascal_to_snake_case(&struct_name_str).to_ascii_uppercase(); + let field_name_caps = field_name_str.to_ascii_uppercase(); + let watch_channel_name = Ident::new( + &format!("{struct_name_caps}_{field_name_caps}_WATCH"), + field.span(), + ); + + let field_name_pascal = snake_to_pascal_case(&field_name_str); + let subscriber_struct_name = Ident::new( + &format!("{struct_name_str}{field_name_pascal}"), + field.span(), + ); + + let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); + + Self { + field_name, + ty, + watch_channel_name, + subscriber_struct_name, + setter_name, + #[cfg(feature = "embassy")] + sender_name: Ident::new(&format!("{field_name_str}_sender"), field.span()), + #[cfg(feature = "embassy")] + max_subscribers: super::BROADCAST_MAX_SUBSCRIBERS, + } + } +} diff --git a/src/controller/mod.rs b/src/controller/mod.rs index f2f076c..e03ef93 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -7,7 +7,9 @@ use syn::{spanned::Spanned, Item, ItemMod, Result}; const ALL_CHANNEL_CAPACITY: usize = 8; const SIGNAL_CHANNEL_CAPACITY: usize = 8; +#[cfg(feature = "embassy")] const BROADCAST_MAX_PUBLISHERS: usize = 1; +#[cfg(feature = "embassy")] const BROADCAST_MAX_SUBSCRIBERS: usize = 16; pub(crate) fn expand_module(input: ItemMod) -> Result { diff --git a/src/lib.rs b/src/lib.rs index e897e92..4f54590 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,11 @@ #![doc = include_str!("../README.md")] +#[cfg(not(any(feature = "embassy", feature = "tokio")))] +compile_error!("Either the `embassy` or `tokio` feature must be enabled"); + +#[cfg(all(feature = "embassy", feature = "tokio"))] +compile_error!("The `embassy` and `tokio` features are mutually exclusive"); + use proc_macro::TokenStream; use syn::{parse_macro_input, punctuated::Punctuated, ItemMod, Meta, Token}; diff --git a/tests/integration.rs b/tests/integration.rs index c5cd062..1ced2b6 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -71,155 +71,168 @@ mod test_controller { use test_controller::*; +#[cfg(feature = "embassy")] #[test] fn test_controller_basic_functionality() { - // Create the controller before spawning the thread to avoid any race conditions. - // The channels used for communication will buffer requests, so it's safe for the - // client to start making calls even if the controller task hasn't fully started yet. let controller = Controller::new(State::Idle, Mode::Normal, 0); - // Run the controller in a background thread. std::thread::spawn(move || { let executor = Box::leak(Box::new(embassy_executor::Executor::new())); executor.run(move |spawner| { - spawner.spawn(controller_task(controller)).unwrap(); + spawner.spawn(controller_embassy_task(controller)).unwrap(); }); }); - // Run the test logic. futures::executor::block_on(async { - // Create client. - let mut client = ControllerClient::new(); - - // Test 1: Subscribe to state changes. - let mut state_stream = client.receive_state_changed().expect("Failed to subscribe"); - - // Test 1a: First poll returns the initial (current) value. - let initial_state = state_stream - .next() - .await - .expect("Should receive initial state"); - assert_eq!(initial_state, State::Idle, "Initial state should be Idle"); - - // Test 2: Subscribe to signals. - let mut error_stream = client - .receive_error_occurred() - .expect("Failed to subscribe to error"); - let mut complete_stream = client - .receive_operation_complete() - .expect("Failed to subscribe to complete"); - - // Test 3: Call a method and verify return value. - let counter = client.get_counter().await; - assert_eq!(counter, 0, "Initial counter should be 0"); - - // Test 4: Call increment and verify it increases. - let counter = client.increment().await; - assert_eq!(counter, 1, "Counter should be 1 after increment"); - - let counter = client.increment().await; - assert_eq!(counter, 2, "Counter should be 2 after second increment"); - - // Test 5: Call method that changes state and emits signal. - let activate_result = client.activate().await; - assert!( - activate_result.is_ok(), - "Activate should succeed from Idle state" - ); - - // Verify we received the state change (raw value, not Changed struct). - let new_state = state_stream - .next() - .await - .expect("Should receive state change"); - assert_eq!(new_state, State::Active, "New state should be Active"); - - // Verify we received the operation_complete signal. - let _complete = complete_stream - .next() - .await - .expect("Should receive operation complete signal"); - - // Test 6: Call method that returns error. - let error_result = client.trigger_error().await; - assert!( - error_result.is_err(), - "trigger_error should return an error" - ); - assert_eq!( - error_result.unwrap_err(), - TestError::OperationFailed, - "Should return OperationFailed error" - ); - - // Verify state changed to Error. - let new_state = state_stream - .next() - .await - .expect("Should receive state change"); - assert_eq!(new_state, State::Error, "New state should be Error"); - - // Verify we received the error signal. - let error_signal = error_stream - .next() - .await - .expect("Should receive error signal"); - assert_eq!(error_signal.code, 42, "Error code should be 42"); - assert_eq!( - error_signal.message.as_str(), - "Test error", - "Error message should match" - ); - - // Test 7: Try to activate again (should fail due to invalid state). - let activate_result = client.activate().await; - assert!( - activate_result.is_err(), - "Activate should fail from Error state" - ); - assert_eq!( - activate_result.unwrap_err(), - TestError::InvalidState, - "Should return InvalidState error" - ); - - // Test 8: Use setter to change mode. - client.set_mode(Mode::Debug).await; - - // Test 9: Call method with no return value. - client.return_nothing().await; - - // Test 10: Use getter with custom name to get state. - let state = client.get_current_state().await; - assert_eq!(state, State::Error, "State should be Error"); - - // Test 11: Use getter with default field name to get mode. - let mode = client.mode().await; - assert_eq!(mode, Mode::Debug, "Mode should be Debug"); - - // Test 12: Use setter with custom name (new syntax). - client.change_state(State::Idle).await; - let state = client.get_current_state().await; - assert_eq!( - state, - State::Idle, - "State should be Idle after change_state" - ); - - // Test 13: Use setter without publish (independent setter). - client.set_counter(100).await; - let counter = client.get_counter().await; - assert_eq!(counter, 100, "Counter should be 100 after set_counter"); - - // If we get here, all tests passed. + run_basic_test().await; }); } -#[embassy_executor::task] +#[cfg(feature = "tokio")] +#[tokio::test] +async fn test_controller_basic_functionality() { + let controller = Controller::new(State::Idle, Mode::Normal, 0); + tokio::spawn(controller_task(controller)); + tokio::task::yield_now().await; + + run_basic_test().await; +} + +async fn run_basic_test() { + let mut client = ControllerClient::new(); + + // Test 1: Subscribe to state changes. + let mut state_stream = client.receive_state_changed().expect("Failed to subscribe"); + + // Test 1a: First poll returns the initial (current) value. + let initial_state = state_stream + .next() + .await + .expect("Should receive initial state"); + assert_eq!(initial_state, State::Idle, "Initial state should be Idle"); + + // Test 2: Subscribe to signals. + let mut error_stream = client + .receive_error_occurred() + .expect("Failed to subscribe to error"); + let mut complete_stream = client + .receive_operation_complete() + .expect("Failed to subscribe to complete"); + + // Test 3: Call a method and verify return value. + let counter = client.get_counter().await; + assert_eq!(counter, 0, "Initial counter should be 0"); + + // Test 4: Call increment and verify it increases. + let counter = client.increment().await; + assert_eq!(counter, 1, "Counter should be 1 after increment"); + + let counter = client.increment().await; + assert_eq!(counter, 2, "Counter should be 2 after second increment"); + + // Test 5: Call method that changes state and emits signal. + let activate_result = client.activate().await; + assert!( + activate_result.is_ok(), + "Activate should succeed from Idle state" + ); + + // Verify we received the state change. + let new_state = state_stream + .next() + .await + .expect("Should receive state change"); + assert_eq!(new_state, State::Active, "New state should be Active"); + + // Verify we received the operation_complete signal. + let _complete = complete_stream + .next() + .await + .expect("Should receive operation complete signal"); + + // Test 6: Call method that returns error. + let error_result = client.trigger_error().await; + assert!( + error_result.is_err(), + "trigger_error should return an error" + ); + assert_eq!( + error_result.unwrap_err(), + TestError::OperationFailed, + "Should return OperationFailed error" + ); + + // Verify state changed to Error. + let new_state = state_stream + .next() + .await + .expect("Should receive state change"); + assert_eq!(new_state, State::Error, "New state should be Error"); + + // Verify we received the error signal. + let error_signal = error_stream + .next() + .await + .expect("Should receive error signal"); + assert_eq!(error_signal.code, 42, "Error code should be 42"); + assert_eq!( + error_signal.message.as_str(), + "Test error", + "Error message should match" + ); + + // Test 7: Try to activate again (should fail due to invalid state). + let activate_result = client.activate().await; + assert!( + activate_result.is_err(), + "Activate should fail from Error state" + ); + assert_eq!( + activate_result.unwrap_err(), + TestError::InvalidState, + "Should return InvalidState error" + ); + + // Test 8: Use setter to change mode. + client.set_mode(Mode::Debug).await; + + // Test 9: Call method with no return value. + client.return_nothing().await; + + // Test 10: Use getter with custom name to get state. + let state = client.get_current_state().await; + assert_eq!(state, State::Error, "State should be Error"); + + // Test 11: Use getter with default field name to get mode. + let mode = client.mode().await; + assert_eq!(mode, Mode::Debug, "Mode should be Debug"); + + // Test 12: Use setter with custom name (new syntax). + client.change_state(State::Idle).await; + let state = client.get_current_state().await; + assert_eq!( + state, + State::Idle, + "State should be Idle after change_state" + ); + + // Test 13: Use setter without publish (independent setter). + client.set_counter(100).await; + let counter = client.get_counter().await; + assert_eq!(counter, 100, "Counter should be 100 after set_counter"); +} + +#[cfg(feature = "tokio")] async fn controller_task(controller: Controller) { controller.run().await; } +#[cfg(feature = "embassy")] +#[embassy_executor::task] +async fn controller_embassy_task(controller: Controller) { + controller.run().await; +} + /// Test that visibility specifiers on struct fields are preserved. #[controller] mod visibility_test_controller { @@ -235,21 +248,13 @@ mod visibility_test_controller { impl Controller {} } +#[cfg(feature = "embassy")] #[test] fn test_visibility_on_fields() { - // Verify struct compiles and fields have correct visibility. let controller = visibility_test_controller::Controller::new(42, -1, true); - - // Public field should be accessible. assert_eq!(controller.public_field, 42); - - // pub(crate) field should be accessible within this crate. assert_eq!(controller.crate_field, -1); - // Note: private_field is not accessible here, which is correct. - // We can only access it through the method. - - // Run the controller in a background thread. std::thread::spawn(move || { let executor = Box::leak(Box::new(embassy_executor::Executor::new())); executor.run(move |spawner| { @@ -260,15 +265,31 @@ fn test_visibility_on_fields() { }); futures::executor::block_on(async { - let client = visibility_test_controller::ControllerClient::new(); - - // Use generated getters from #[controller(getter)] attribute. - assert_eq!(client.public_field().await, 42); - assert_eq!(client.crate_field().await, -1); - assert_eq!(client.private_field().await, true); + run_visibility_test().await; }); } +#[cfg(feature = "tokio")] +#[tokio::test] +async fn test_visibility_on_fields() { + let controller = visibility_test_controller::Controller::new(42, -1, true); + assert_eq!(controller.public_field, 42); + assert_eq!(controller.crate_field, -1); + + tokio::spawn(async move { controller.run().await }); + tokio::task::yield_now().await; + + run_visibility_test().await; +} + +async fn run_visibility_test() { + let client = visibility_test_controller::ControllerClient::new(); + assert_eq!(client.public_field().await, 42); + assert_eq!(client.crate_field().await, -1); + assert_eq!(client.private_field().await, true); +} + +#[cfg(feature = "embassy")] #[embassy_executor::task] async fn visibility_controller_task(controller: visibility_test_controller::Controller) { controller.run().await; @@ -311,11 +332,11 @@ mod poll_test_controller { } /// Test that poll methods are called at the expected intervals. +#[cfg(feature = "embassy")] #[test] fn poll_methods() { use embassy_time::{Duration, MockDriver}; - // Reset mock driver and counters. let driver = MockDriver::get(); driver.reset(); POLL_A_COUNT.store(0, Ordering::SeqCst); @@ -323,11 +344,8 @@ fn poll_methods() { POLL_C_COUNT.store(0, Ordering::SeqCst); let controller = poll_test_controller::Controller::new(42); - - // Verify struct fields are accessible. assert_eq!(controller.value, 42); - // Run the controller in a background thread. std::thread::spawn(move || { let executor = Box::leak(Box::new(embassy_executor::Executor::new())); executor.run(move |spawner| { @@ -338,28 +356,63 @@ fn poll_methods() { // Give the executor a moment to start. std::thread::sleep(std::time::Duration::from_millis(10)); - // Advance mock time by 50ms - poll_a and poll_b should fire once. - driver.advance(Duration::from_millis(50)); - std::thread::sleep(std::time::Duration::from_millis(10)); - assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), 1); - assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), 1); - assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), 0); + futures::executor::block_on(run_poll_test(|millis| async move { + driver.advance(Duration::from_millis(millis)); + std::thread::sleep(std::time::Duration::from_millis(10)); + })); +} - // Advance another 50ms (total 100ms) - poll_a/poll_b fire again, poll_c fires once. - driver.advance(Duration::from_millis(50)); - std::thread::sleep(std::time::Duration::from_millis(10)); - assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), 2); - assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), 2); - assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), 1); +#[cfg(feature = "tokio")] +#[tokio::test(start_paused = true)] +async fn poll_methods() { + POLL_A_COUNT.store(0, Ordering::SeqCst); + POLL_B_COUNT.store(0, Ordering::SeqCst); + POLL_C_COUNT.store(0, Ordering::SeqCst); - // Advance another 100ms (total 200ms) - poll_a/poll_b fire 2 more times, poll_c fires once. - driver.advance(Duration::from_millis(100)); - std::thread::sleep(std::time::Duration::from_millis(10)); - assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), 4); - assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), 4); - assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), 2); + let controller = poll_test_controller::Controller::new(42); + assert_eq!(controller.value, 42); + + tokio::spawn(async move { controller.run().await }); + + // Yield to let the controller task start and skip the initial ticks. + for _ in 0..10 { + tokio::task::yield_now().await; + } + + run_poll_test(|millis| async move { + tokio::time::advance(std::time::Duration::from_millis(millis)).await; + for _ in 0..10 { + tokio::task::yield_now().await; + } + }) + .await; +} + +async fn run_poll_test(advance_and_settle: F) +where + F: Fn(u64) -> Fut, + Fut: core::future::Future, +{ + // Advance 50ms - poll_a and poll_b should fire once. + advance_and_settle(50).await; + assert_poll_counts(1, 1, 0); + + // Advance another 50ms (total 100ms). + advance_and_settle(50).await; + assert_poll_counts(2, 2, 1); + + // Advance another 100ms (total 200ms). + advance_and_settle(100).await; + assert_poll_counts(4, 4, 2); +} + +fn assert_poll_counts(a: u32, b: u32, c: u32) { + assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), a); + assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), b); + assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), c); } +#[cfg(feature = "embassy")] #[embassy_executor::task] async fn poll_controller_task(controller: poll_test_controller::Controller) { controller.run().await; From 8c361138a39c65bc530f0ecf8e32a202882852aa Mon Sep 17 00:00:00 2001 From: Zeeshan Ali Khan Date: Sat, 21 Mar 2026 22:41:14 +0100 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=94=96=20Bump=20semver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We just added a default feature, which is a breaking change. Besides, the next release will have some major changes anyway. --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 04b713a..5d1838f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,7 +207,7 @@ dependencies = [ [[package]] name = "firmware-controller" -version = "0.4.2" +version = "0.5.0" dependencies = [ "critical-section", "embassy-executor", diff --git a/Cargo.toml b/Cargo.toml index 709d1de..5bcbbb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "firmware-controller" description = "Controller (actor) macro to decouple interactions between components, supporting both embassy (no_std) and tokio (std) backends." -version = "0.4.2" +version = "0.5.0" edition = "2021" authors = [ "Zeeshan Ali Khan ",