Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,13 @@ impl CoroutineKind {
matches!(self, CoroutineKind::Desugared(_, CoroutineSource::Fn))
}

pub fn is_async_desugaring(self) -> bool {
matches!(
self,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
)
}

pub fn to_plural_string(&self) -> String {
match self {
CoroutineKind::Desugared(d, CoroutineSource::Fn) => format!("{d:#}fn bodies"),
Expand Down
200 changes: 94 additions & 106 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,19 +563,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
);
}

/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// The actual should be `&mut Context<'_>`. This performs the substitution:
/// - create a new local `_r` of type `ResumeTy`;
/// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
/// - let all the code use `_r` instead of `_2`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
Expand All @@ -584,95 +580,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = Ty::new_task_context(tcx);
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
ty::GenericArgs::empty(),
body.typing_env(tcx),
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
);

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
// Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
// and set `CTX_ARG: &mut Context<'_>`.
let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
body.local_decls.swap(CTX_ARG, resume_local);
RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
// Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
// Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
// at the function entry to make the bridge.
let source_info = SourceInfo::outermost(body.span);
let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
let nonnull_rhs =
Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
let resume_rhs = Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
resume_ty_def_id,
VariantIdx::ZERO,
ty::GenericArgs::empty(),
None,
None,
)),
indexvec![Operand::Move(nonnull_local.into())],
);
let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
0..0,
[Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
);
}

/// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
/// Both types are just a single pointer, but liveness analysis does not know that and
/// supposes that the operand and the destination are live at the same time.
/// Forcibly inline those calls to avoid this.
fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = Ty::new_task_context(tcx);
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
ty::GenericArgs::empty(),
body.typing_env(tcx),
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
);

for bb in body.basic_blocks.indices() {
let bb_data = &body[bb];
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
for bb_data in body.basic_blocks.as_mut().iter_mut() {
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
{
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
let terminator = bb_data.terminator_mut();
if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
&& let func_ty = func.ty(&body.local_decls, tcx)
&& let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
&& let [arg] = &**args
&& let Some(place) = arg.node.place()
{
let arg =
Rvalue::Cast(
CastKind::Transmute,
Operand::Copy(place.project_deeper(
&[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
tcx,
)),
context_mut_ref,
);
let assign = Statement::new(
terminator.source_info,
StatementKind::Assign(Box::new((*destination, arg))),
);
terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
bb_data.statements.push(assign);
}
}
context_mut_ref
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
bug!();
};
let [arg] = *Box::try_from(args).unwrap();
let local = arg.node.place().unwrap().local;

let arg = Rvalue::Use(arg.node, WithRetag::Yes);
let assign =
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
}

#[cfg_attr(not(debug_assertions), allow(unused))]
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
}

/// Transforms the `body` of the coroutine applying the following transform:
///
/// - Remove the `resume` argument.
///
/// Ideally the async lowering would not add the `resume` argument.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `resume` argument for the time being. After this transform,
/// the coroutine body doesn't have the `resume` argument.
fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
// This leaves the local representing the `resume` argument in place,
// but turns it into a regular local variable. This is cheaper than
// adjusting all local references in the body after removing it.
body.arg_count = 1;
}

struct LivenessInfo {
Expand Down Expand Up @@ -1292,6 +1283,10 @@ fn create_coroutine_resume_function<'tcx>(

pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

if transform.coroutine_kind.is_async_desugaring() {
transform_async_context(tcx, body);
}

if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
dumper.dump_mir(body);
}
Expand Down Expand Up @@ -1507,18 +1502,15 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// (finally in open_drop_for_tuple) before async drop expansion.
// Async drops, produced by this drop elaboration, will be expanded,
// and corresponding futures kept in layout.
let has_async_drops = matches!(
coroutine_kind,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
) && has_expandable_async_drops(tcx, body, coroutine_ty);
let has_async_drops = coroutine_kind.is_async_desugaring()
&& has_expandable_async_drops(tcx, body, coroutine_ty);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if matches!(
coroutine_kind,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
) {
let context_mut_ref = transform_async_context(tcx, body);
expand_async_drops(tcx, body, context_mut_ref, coroutine_kind, coroutine_ty);
if coroutine_kind.is_async_desugaring() {
eliminate_get_context_calls(tcx, body);
}

if has_async_drops {
expand_async_drops(tcx, body, coroutine_kind, coroutine_ty);

if let Some(dumper) = MirDumper::new(tcx, "coroutine_async_drop_expand", body) {
dumper.dump_mir(body);
Expand Down Expand Up @@ -1591,13 +1583,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
}),
);

// Update our MIR struct to reflect the changes we've made
body.arg_count = 2; // self, resume arg
body.spread_arg = None;

// Remove the context argument within generator bodies.
if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
transform_gen_context(body);
body.arg_count = 1;
}

// The original arguments to the function are no longer arguments, mark them as such.
Expand Down Expand Up @@ -1652,7 +1640,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

// For coroutine with sync drop, generating async proxy for `future_drop_poll` call
let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
deref_finder(tcx, &mut proxy_shim, false);
body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
}
Expand Down
Loading
Loading