Skip to content

Commit 4ef5e84

Browse files
committed
refactor channel acquisition/creation
1 parent 9801a57 commit 4ef5e84

1 file changed

Lines changed: 42 additions & 36 deletions

File tree

src/channel.rs

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -116,37 +116,40 @@ async fn broadcast_to_channel(
116116
State(state): State<Arc<AppState>>,
117117
body: Body,
118118
) -> axum::response::Result<()> {
119-
let mut channel_clients = state.channel_clients.lock().await;
119+
let channel_tx = {
120+
// dropped at the end of this scope
121+
let mut channel_clients = state.channel_clients.lock().await;
120122

121-
let namespace_channels = match channel_clients.entry(namespace) {
122-
std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
123-
std::collections::hash_map::Entry::Vacant(e) => e.insert(HashMap::new()),
124-
};
123+
let channels = match channel_clients.entry(namespace) {
124+
std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
125+
std::collections::hash_map::Entry::Vacant(e) => e.insert(HashMap::new()),
126+
};
125127

126-
let tx = if let Some((tx, _rx)) = namespace_channels.get(&channel_name) {
127-
tx.clone()
128-
} else {
129-
let (tx, rx) = flume::bounded(0);
128+
if let Some((tx, _rx)) = channels.get(&channel_name) {
129+
tx.clone()
130+
} else {
131+
let (tx, rx) = flume::bounded(0);
130132

131-
namespace_channels.insert(channel_name, (tx.clone(), rx));
133+
channels.insert(channel_name, (tx.clone(), rx));
132134

133-
tx
135+
tx
136+
}
134137
};
135138

136-
drop(channel_clients);
137-
138139
let body_stream = body.into_data_stream();
139140

140141
let (drop_guard, drop_guard_rx) = DropGuard::new();
141142

142-
tx.send_async(Payload {
143-
body_stream,
144-
headers,
145-
drop_guard,
146-
})
147-
.await
148-
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
143+
channel_tx
144+
.send_async(Payload {
145+
body_stream,
146+
headers,
147+
drop_guard,
148+
})
149+
.await
150+
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
149151

152+
// wait for the drop guard to finish before we complete this http request
150153
drop_guard_rx
151154
.await
152155
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
@@ -158,32 +161,35 @@ async fn subscribe_to_channel(
158161
Path((namespace, channel_name)): Path<(Namespace, ChannelName)>,
159162
State(state): State<Arc<AppState>>,
160163
) -> axum::response::Result<impl IntoResponse> {
161-
let mut channel_clients = state.channel_clients.lock().await;
162-
163-
let namespace_channels = match channel_clients.entry(namespace) {
164-
std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
165-
std::collections::hash_map::Entry::Vacant(e) => e.insert(HashMap::new()),
166-
};
164+
let channel_rx = {
165+
// dropped at the end of this scope
166+
let mut channel_clients = state.channel_clients.lock().await;
167167

168-
let rx = if let Some((_tx, rx)) = namespace_channels.get(&channel_name) {
169-
rx.clone()
170-
} else {
171-
let (tx, rx) = flume::bounded(0);
168+
let channels = match channel_clients.entry(namespace) {
169+
std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
170+
std::collections::hash_map::Entry::Vacant(e) => e.insert(HashMap::new()),
171+
};
172172

173-
namespace_channels.insert(channel_name, (tx, rx.clone()));
173+
let rx = if let Some((_tx, rx)) = channels.get(&channel_name) {
174+
rx.clone()
175+
} else {
176+
let (tx, rx) = flume::bounded(0);
174177

175-
rx
176-
};
178+
channels.insert(channel_name, (tx, rx.clone()));
177179

178-
drop(channel_clients);
180+
rx
181+
};
179182

180-
let rx = rx.into_recv_async();
183+
rx.into_recv_async()
184+
};
181185

182186
let Payload {
183187
body_stream,
184188
headers: producer_request_headers,
185189
drop_guard: _drop_guard,
186-
} = rx.await.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
190+
} = channel_rx
191+
.await
192+
.map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
187193

188194
let body = Body::from_stream(body_stream);
189195

0 commit comments

Comments
 (0)