From 72f70b95fe3e46f622de84a2d38d2ea6161523cc Mon Sep 17 00:00:00 2001 From: James Date: Wed, 1 Apr 2026 00:35:20 +0100 Subject: [PATCH] fix: align export async function checks --- .../src/transforms/proxy-export.test.ts | 47 ++++++++++++++++++- .../plugin-rsc/src/transforms/proxy-export.ts | 33 +++---------- .../src/transforms/wrap-export.test.ts | 3 ++ .../plugin-rsc/src/transforms/wrap-export.ts | 47 +++++++++++-------- 4 files changed, 83 insertions(+), 47 deletions(-) diff --git a/packages/plugin-rsc/src/transforms/proxy-export.test.ts b/packages/plugin-rsc/src/transforms/proxy-export.test.ts index 4be88ac00..bfce385d6 100644 --- a/packages/plugin-rsc/src/transforms/proxy-export.test.ts +++ b/packages/plugin-rsc/src/transforms/proxy-export.test.ts @@ -4,7 +4,10 @@ import { transformProxyExport } from './proxy-export' import { debugSourceMap } from './test-utils' import { transformWrapExport } from './wrap-export' -async function testTransform(input: string, options?: { keep?: boolean }) { +async function testTransform( + input: string, + options?: { keep?: boolean; rejectNonAsyncFunction?: boolean }, +) { const ast = await parseAstAsync(input) const result = transformProxyExport(ast, { code: input, @@ -239,4 +242,46 @@ export const MyClientComp = () => { throw new Error('...') } } `) }) + + test('reject non async function', async () => { + const accepted = [ + 'export async function f() {}', + 'export default async function f() {}', + 'export const fn = async function fn() {}', + 'export const fn = async () => {}', + 'export const fn = async () => {}, fn2 = x', + 'export const fn = x', + 'export const fn = x({ x: y })', + 'export const fn = x(async () => {})', + 'export default x', + 'const y = x; export { y }', + 'export const fn = x(() => {})', + 'export const testAction = actionClient.action(async () => { return { message: "Hello, world!" }; });', + ] + + const rejected = [ + 'export function f() {}', + 'export default function f() {}', + 'export const fn = function fn() {}', + 'export const fn = () => {}', + 'export const fn = x, fn2 = () => {}', + 'export class Cls {}', + 'export const Cls = class {}', + 'export const Cls = class Foo {}', + ] + + for (const code of accepted) { + await expect( + testTransform(code, { rejectNonAsyncFunction: true }), + ).resolves.not.toThrow() + } + + for (const code of rejected) { + await expect( + testTransform(code, { rejectNonAsyncFunction: true }), + ).rejects.toThrow(/unsupported non async function/) + } + + expect.assertions(rejected.length + accepted.length) + }) }) diff --git a/packages/plugin-rsc/src/transforms/proxy-export.ts b/packages/plugin-rsc/src/transforms/proxy-export.ts index f08a05701..36bc06d07 100644 --- a/packages/plugin-rsc/src/transforms/proxy-export.ts +++ b/packages/plugin-rsc/src/transforms/proxy-export.ts @@ -3,6 +3,7 @@ import type { Node, Program } from 'estree' import MagicString from 'magic-string' import { extract_names } from 'periscopic' import { hasDirective } from './utils' +import { validateNonAsyncFunction } from './wrap-export' export type TransformProxyExportOptions = { /** Required for source map and `keep` options */ @@ -59,14 +60,6 @@ export function transformProxyExport( output.update(node.start, node.end, newCode) } - function validateNonAsyncFunction(node: Node, ok?: boolean) { - if (options.rejectNonAsyncFunction && !ok) { - throw Object.assign(new Error(`unsupported non async function`), { - pos: node.start, - }) - } - } - for (const node of ast.body) { if (node.type === 'ExportNamedDeclaration') { if (node.declaration) { @@ -77,24 +70,15 @@ export function transformProxyExport( /** * export function foo() {} */ - validateNonAsyncFunction( - node, - node.declaration.type === 'FunctionDeclaration' && - node.declaration.async, - ) + validateNonAsyncFunction(options, node.declaration) createExport(node, [node.declaration.id.name]) } else if (node.declaration.type === 'VariableDeclaration') { /** * export const foo = 1, bar = 2 */ - validateNonAsyncFunction( - node, - node.declaration.declarations.every( - (decl) => - decl.init?.type === 'ArrowFunctionExpression' && - decl.init.async, - ), - ) + for (const decl of node.declaration.declarations) { + if (decl.init) validateNonAsyncFunction(options, decl.init) + } if (options.keep && options.code) { if (node.declaration.declarations.length === 1) { const decl = node.declaration.declarations[0]! @@ -149,12 +133,7 @@ export function transformProxyExport( * export default () => {} */ if (node.type === 'ExportDefaultDeclaration') { - validateNonAsyncFunction( - node, - node.declaration.type === 'Identifier' || - (node.declaration.type === 'FunctionDeclaration' && - node.declaration.async), - ) + validateNonAsyncFunction(options, node.declaration) createExport(node, ['default']) continue } diff --git a/packages/plugin-rsc/src/transforms/wrap-export.test.ts b/packages/plugin-rsc/src/transforms/wrap-export.test.ts index 28dfa19a4..73cadd1d8 100644 --- a/packages/plugin-rsc/src/transforms/wrap-export.test.ts +++ b/packages/plugin-rsc/src/transforms/wrap-export.test.ts @@ -321,6 +321,7 @@ export default Page; `export default x`, `const y = x; export { y }`, `export const fn = x(() => {})`, // rejected by next.js + `export const testAction = actionClient.action(async () => { return { message: "Hello, world!" }; });`, ] const rejected = [ @@ -330,6 +331,8 @@ export default Page; `export const fn = () => {}`, `export const fn = x, fn2 = () => {}`, `export class Cls {}`, + `export const Cls = class {}`, + `export const Cls = class Foo {}`, ] async function toActual(input: string) { diff --git a/packages/plugin-rsc/src/transforms/wrap-export.ts b/packages/plugin-rsc/src/transforms/wrap-export.ts index 5c6cdcae4..708fe561b 100644 --- a/packages/plugin-rsc/src/transforms/wrap-export.ts +++ b/packages/plugin-rsc/src/transforms/wrap-export.ts @@ -1,5 +1,10 @@ import { tinyassert } from '@hiogawa/utils' -import type { Node, Program } from 'estree' +import type { + MaybeNamedClassDeclaration, + MaybeNamedFunctionDeclaration, + Node, + Program, +} from 'estree' import MagicString from 'magic-string' import { extract_names } from 'periscopic' @@ -21,6 +26,25 @@ export type TransformWrapExportOptions = { filter?: TransformWrapExportFilter } +export function validateNonAsyncFunction( + opts: { rejectNonAsyncFunction?: boolean }, + node: Node | MaybeNamedFunctionDeclaration | MaybeNamedClassDeclaration, +): void { + if (!opts.rejectNonAsyncFunction) return + if ( + node.type === 'ClassDeclaration' || + node.type === 'ClassExpression' || + ((node.type === 'FunctionDeclaration' || + node.type === 'FunctionExpression' || + node.type === 'ArrowFunctionExpression') && + !node.async) + ) { + throw Object.assign(new Error(`unsupported non async function`), { + pos: node.start, + }) + } +} + export function transformWrapExport( input: string, ast: Program, @@ -83,21 +107,6 @@ export function transformWrapExport( ) } - function validateNonAsyncFunction(node: Node) { - if (!options.rejectNonAsyncFunction) return - if ( - node.type === 'ClassDeclaration' || - ((node.type === 'FunctionDeclaration' || - node.type === 'FunctionExpression' || - node.type === 'ArrowFunctionExpression') && - !node.async) - ) { - throw Object.assign(new Error(`unsupported non async function`), { - pos: node.start, - }) - } - } - for (const node of ast.body) { // named exports if (node.type === 'ExportNamedDeclaration') { @@ -109,7 +118,7 @@ export function transformWrapExport( /** * export function foo() {} */ - validateNonAsyncFunction(node.declaration) + validateNonAsyncFunction(options, node.declaration) const name = node.declaration.id.name wrapSimple(node.start, node.declaration.start, [ { name, meta: { isFunction: true, declName: name } }, @@ -120,7 +129,7 @@ export function transformWrapExport( */ for (const decl of node.declaration.declarations) { if (decl.init) { - validateNonAsyncFunction(decl.init) + validateNonAsyncFunction(options, decl.init) } } if (node.declaration.kind === 'const') { @@ -203,7 +212,7 @@ export function transformWrapExport( * export default () => {} */ if (node.type === 'ExportDefaultDeclaration') { - validateNonAsyncFunction(node.declaration as Node) + validateNonAsyncFunction(options, node.declaration) let localName: string let isFunction = false let declName: string | undefined