diff --git a/src/ir/intrinsics.cpp b/src/ir/intrinsics.cpp index b2a5e1ff11a..6dd4c072324 100644 --- a/src/ir/intrinsics.cpp +++ b/src/ir/intrinsics.cpp @@ -101,29 +101,29 @@ std::vector Intrinsics::getConfigureAllFunctions(Call* call) { return ret; } -std::vector Intrinsics::getConfigureAllFunctions() { - // ConfigureAll in a start function makes its functions callable. +std::vector Intrinsics::getJSCalledFunctions() { + std::vector ret; + for (auto& func : module.functions) { + if (getAnnotations(func.get()).jsCalled) { + ret.push_back(func->name); + } + } + + // ConfigureAlls in a start function make their functions callable. if (module.start) { auto* start = module.getFunction(module.start); if (!start->imported()) { FindAll calls(start->body); - // Look for the (single) configureAll. - Call* configureAll = nullptr; for (auto* call : calls.list) { if (isConfigureAll(call)) { - if (configureAll) { - Fatal() << "Multiple configureAlls"; - } else { - configureAll = call; + for (auto name : getConfigureAllFunctions(call)) { + ret.push_back(name); } } } - if (configureAll) { - return getConfigureAllFunctions(configureAll); - } } } - return {}; + return ret; } } // namespace wasm diff --git a/src/ir/intrinsics.h b/src/ir/intrinsics.h index 51ac3a6436b..9be2caebc7d 100644 --- a/src/ir/intrinsics.h +++ b/src/ir/intrinsics.h @@ -109,8 +109,11 @@ class Intrinsics { // // where the segment $seg is of size N. std::vector getConfigureAllFunctions(Call* call); - // As above, but looks through the module to find the configureAll. - std::vector getConfigureAllFunctions(); + + // Returns the names of all functions that are JS-called. That includes ones + // in configureAll (which we look through the module for), and also those + // annotated with @binaryen.js.called. + std::vector getJSCalledFunctions(); // Get the code annotations for an expression in a function. CodeAnnotation getAnnotations(Expression* curr, Function* func) { @@ -149,6 +152,9 @@ class Intrinsics { if (!ret.removableIfUnused) { ret.removableIfUnused = funcAnnotations.removableIfUnused; } + if (!ret.jsCalled) { + ret.jsCalled = funcAnnotations.jsCalled; + } } return ret; diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp index c018cf628d7..06825343d1f 100644 --- a/src/ir/possible-contents.cpp +++ b/src/ir/possible-contents.cpp @@ -2454,8 +2454,8 @@ Flower::Flower(Module& wasm, const PassOptions& options) } } - // configureAll functions are called from outside the module, as if exported. - for (auto func : Intrinsics(wasm).getConfigureAllFunctions()) { + // JS-called functions are called from outside the module, as if exported. + for (auto func : Intrinsics(wasm).getJSCalledFunctions()) { calledFromOutside.insert(func); } diff --git a/src/parser/contexts.h b/src/parser/contexts.h index 3978afff45d..5ab172224e9 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h @@ -1317,6 +1317,8 @@ struct AnnotationParserCtx { inlineHint = &a; } else if (a.kind == Annotations::RemovableIfUnusedHint) { ret.removableIfUnused = true; + } else if (a.kind == Annotations::JSCalledHint) { + ret.jsCalled = true; } } diff --git a/src/passes/GlobalTypeOptimization.cpp b/src/passes/GlobalTypeOptimization.cpp index 5ea29f8e335..5ecc9ebae0a 100644 --- a/src/passes/GlobalTypeOptimization.cpp +++ b/src/passes/GlobalTypeOptimization.cpp @@ -420,7 +420,7 @@ struct GlobalTypeOptimization : public Pass { if (!wasm.features.hasCustomDescriptors()) { return; } - for (auto func : Intrinsics(wasm).getConfigureAllFunctions()) { + for (auto func : Intrinsics(wasm).getJSCalledFunctions()) { // Look at the result types being returned to JS and make sure we preserve // any configured prototypes they might expose. for (auto type : wasm.getFunction(func)->getResults()) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index d92a412dbbd..35b5a3076ec 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2797,6 +2797,12 @@ void PrintSExpression::printCodeAnnotations(Expression* curr) { restoreNormalColor(o); doIndent(o, indent); } + if (annotation.jsCalled) { + Colors::grey(o); + o << "(@" << Annotations::JSCalledHint << ")\n"; + restoreNormalColor(o); + doIndent(o, indent); + } } } diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index 64f56744624..c52f9872246 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -184,7 +184,16 @@ struct Noter : public PostWalker> { noteCallRef(curr->target->type.getHeapType()); } - void visitRefFunc(RefFunc* curr) { noteRefFunc(curr->func); } + void visitRefFunc(RefFunc* curr) { + // If the target is js-called then a reference is as strong as a use. + auto target = curr->func; + Intrinsics intrinsics(*getModule()); + if (intrinsics.getAnnotations(getModule()->getFunction(target)).jsCalled) { + use({ModuleElementKind::Function, target}); + } else { + noteRefFunc(target); + } + } void visitStructGet(StructGet* curr) { if (curr->ref->type == Type::unreachable || curr->ref->type.isNull()) { diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp index d1f1a14456d..cb5155a7a14 100644 --- a/src/passes/SignaturePruning.cpp +++ b/src/passes/SignaturePruning.cpp @@ -176,9 +176,10 @@ struct SignaturePruning : public Pass { allInfo[tag->type].optimizable = false; } - // configureAll functions are signature-called, and must also not be - // modified. - for (auto func : Intrinsics(*module).getConfigureAllFunctions()) { + // Signature-called functions must also not be modified. + // TODO: Explore whether removing parameters from the end could be + // beneficial (check if it does not regress call performance with JS). + for (auto func : Intrinsics(*module).getJSCalledFunctions()) { allInfo[module->getFunction(func)->type.getHeapType()].optimizable = false; } diff --git a/src/passes/SignatureRefining.cpp b/src/passes/SignatureRefining.cpp index b86a4a9eb60..67132733ec1 100644 --- a/src/passes/SignatureRefining.cpp +++ b/src/passes/SignatureRefining.cpp @@ -162,9 +162,8 @@ struct SignatureRefining : public Pass { } } - // configureAll functions are signature-called, which means their params - // must not be refined. - for (auto func : Intrinsics(*module).getConfigureAllFunctions()) { + // Signature-called functions must not have params refined. + for (auto func : Intrinsics(*module).getJSCalledFunctions()) { allInfo[module->getFunction(func)->type.getHeapType()].canModifyParams = false; } diff --git a/src/passes/StripToolchainAnnotations.cpp b/src/passes/StripToolchainAnnotations.cpp index 6118998eabf..c590afe964d 100644 --- a/src/passes/StripToolchainAnnotations.cpp +++ b/src/passes/StripToolchainAnnotations.cpp @@ -43,6 +43,7 @@ struct StripToolchainAnnotations // Remove the toolchain-specific annotations. auto& annotation = iter->second; annotation.removableIfUnused = false; + annotation.jsCalled = false; // If nothing remains, remove the entire annotation. if (annotation == CodeAnnotation()) { diff --git a/src/passes/Unsubtyping.cpp b/src/passes/Unsubtyping.cpp index 08e1c47feda..629d86c34c7 100644 --- a/src/passes/Unsubtyping.cpp +++ b/src/passes/Unsubtyping.cpp @@ -612,7 +612,7 @@ struct Unsubtyping : Pass, Noter { return; } Type anyref(HeapType::any, Nullable); - for (auto func : Intrinsics(wasm).getConfigureAllFunctions()) { + for (auto func : Intrinsics(wasm).getJSCalledFunctions()) { // Parameter types flow into Wasm and are implicitly cast from any. for (auto type : wasm.getFunction(func)->getParams()) { if (Type::isSubType(type, anyref)) { diff --git a/src/wasm-annotations.h b/src/wasm-annotations.h index fec7f067795..17911a78ecb 100644 --- a/src/wasm-annotations.h +++ b/src/wasm-annotations.h @@ -28,6 +28,7 @@ namespace wasm::Annotations { extern const Name BranchHint; extern const Name InlineHint; extern const Name RemovableIfUnusedHint; +extern const Name JSCalledHint; } // namespace wasm::Annotations diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 3bd3199b67a..125040aa2d2 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1442,6 +1442,7 @@ class WasmBinaryWriter { std::optional getBranchHintsBuffer(); std::optional getInlineHintsBuffer(); std::optional getRemovableIfUnusedHintsBuffer(); + std::optional getJSCalledHintsBuffer(); // helpers void writeInlineString(std::string_view name); @@ -1734,7 +1735,8 @@ class WasmBinaryReader { void readBranchHints(size_t payloadLen); void readInlineHints(size_t payloadLen); - void readremovableIfUnusedHints(size_t payloadLen); + void readRemovableIfUnusedHints(size_t payloadLen); + void readJSCalledHints(size_t payloadLen); std::tuple readMemoryAccess(bool isAtomic, bool isRMW); diff --git a/src/wasm.h b/src/wasm.h index e74d6532ac2..87258b14c36 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -2245,15 +2245,23 @@ struct CodeAnnotation { static const uint8_t AlwaysInline = 127; std::optional inline_; - // Toolchain hint: If this expression's result is unused, then the entire - // thing can be considered dead and removable. See - // + // Toolchain hints, see // https://github.com/WebAssembly/binaryen/wiki/Optimizer-Cookbook#intrinsics + + // If this expression's result is unused, then the entire thing can be + // considered dead and removable. bool removableIfUnused = false; + // This should be assumed to be called from JS, even in closed world. Being + // called from JS means that the call happens in a non-typed way, with only + // the signature mattering ("signature-called"). In particular, rec group type + // identity does not matter for such functions. + bool jsCalled = false; + bool operator==(const CodeAnnotation& other) const { return branchLikely == other.branchLikely && inline_ == other.inline_ && - removableIfUnused == other.removableIfUnused; + removableIfUnused == other.removableIfUnused && + jsCalled == other.jsCalled; } }; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 85768b55fbc..23077d236a9 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1628,6 +1628,7 @@ std::optional WasmBinaryWriter::writeCodeAnnotations() { append(getBranchHintsBuffer()); append(getInlineHintsBuffer()); append(getRemovableIfUnusedHintsBuffer()); + append(getJSCalledHintsBuffer()); return ret; } @@ -1783,6 +1784,17 @@ WasmBinaryWriter::getRemovableIfUnusedHintsBuffer() { }); } +std::optional +WasmBinaryWriter::getJSCalledHintsBuffer() { + return writeExpressionHints( + Annotations::JSCalledHint, + [](const CodeAnnotation& annotation) { return annotation.jsCalled; }, + [](const CodeAnnotation& annotation, BufferWithRandomAccess& buffer) { + // Hint size, always empty. + buffer << U32LEB(0); + }); +} + void WasmBinaryWriter::writeData(const char* data, size_t size) { for (size_t i = 0; i < size; i++) { o << int8_t(data[i]); @@ -2055,7 +2067,8 @@ void WasmBinaryReader::preScan() { if (sectionName == Annotations::BranchHint || sectionName == Annotations::InlineHint || - sectionName == Annotations::RemovableIfUnusedHint) { + sectionName == Annotations::RemovableIfUnusedHint || + sectionName == Annotations::JSCalledHint) { // Code annotations require code locations. // TODO: We could note which functions require code locations, as an // optimization. @@ -2215,8 +2228,11 @@ void WasmBinaryReader::readCustomSection(size_t payloadLen) { } else if (sectionName == Annotations::RemovableIfUnusedHint) { deferredAnnotationSections.push_back( AnnotationSectionInfo{pos, [this, payloadLen]() { - this->readremovableIfUnusedHints(payloadLen); + this->readRemovableIfUnusedHints(payloadLen); }}); + } else if (sectionName == Annotations::JSCalledHint) { + deferredAnnotationSections.push_back(AnnotationSectionInfo{ + pos, [this, payloadLen]() { this->readJSCalledHints(payloadLen); }}); } else { // an unfamiliar custom section if (sectionName.equals(BinaryConsts::CustomSections::Linking)) { @@ -5527,7 +5543,7 @@ void WasmBinaryReader::readInlineHints(size_t payloadLen) { }); } -void WasmBinaryReader::readremovableIfUnusedHints(size_t payloadLen) { +void WasmBinaryReader::readRemovableIfUnusedHints(size_t payloadLen) { readExpressionHints(Annotations::RemovableIfUnusedHint, payloadLen, [&](CodeAnnotation& annotation) { @@ -5540,6 +5556,18 @@ void WasmBinaryReader::readremovableIfUnusedHints(size_t payloadLen) { }); } +void WasmBinaryReader::readJSCalledHints(size_t payloadLen) { + readExpressionHints( + Annotations::JSCalledHint, payloadLen, [&](CodeAnnotation& annotation) { + auto size = getU32LEB(); + if (size != 0) { + throwError("bad jsCalledHint size"); + } + + annotation.jsCalled = true; + }); +} + std::tuple WasmBinaryReader::readMemoryAccess(bool isAtomic, bool isRMW) { auto rawAlignment = getU32LEB(); diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index fd656f9b0fb..23ff8764971 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp @@ -2664,16 +2664,19 @@ void IRBuilder::applyAnnotations(Expression* expr, } if (annotation.inline_) { - // Only possible inside functions. assert(func); func->codeAnnotations[expr].inline_ = annotation.inline_; } if (annotation.removableIfUnused) { - // Only possible inside functions. assert(func); func->codeAnnotations[expr].removableIfUnused = true; } + + if (annotation.jsCalled) { + assert(func); + func->codeAnnotations[expr].jsCalled = true; + } } } // namespace wasm diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 4e901c152de..959b6cd4bfe 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -70,6 +70,7 @@ namespace Annotations { const Name BranchHint = "metadata.code.branch_hint"; const Name InlineHint = "metadata.code.inline"; const Name RemovableIfUnusedHint = "binaryen.removable.if.unused"; +const Name JSCalledHint = "binaryen.js.called"; } // namespace Annotations diff --git a/test/lit/js-called.wast b/test/lit/js-called.wast new file mode 100644 index 00000000000..5766fb94550 --- /dev/null +++ b/test/lit/js-called.wast @@ -0,0 +1,24 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. + +;; RUN: wasm-opt -all %s -S -o - | filecheck %s +;; RUN: wasm-opt -all --roundtrip %s -S -o - | filecheck %s + +;; Test text and binary handling of @binaryen.js.called. + +(module + ;; CHECK: (type $0 (func)) + + ;; CHECK: (@binaryen.js.called) + ;; CHECK-NEXT: (func $func-annotation (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (@binaryen.js.called) + (func $func-annotation + (drop + (i32.const 0) + ) + ) +) + diff --git a/test/lit/passes/remove-unused-module-elements-js-called.wast b/test/lit/passes/remove-unused-module-elements-js-called.wast new file mode 100644 index 00000000000..f26223051e5 --- /dev/null +++ b/test/lit/passes/remove-unused-module-elements-js-called.wast @@ -0,0 +1,39 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. + +;; RUN: foreach %s %t wasm-opt --remove-unused-module-elements --closed-world -all -S -o - | filecheck %s + +(module + ;; CHECK: (type $0 (func)) + + ;; CHECK: (elem declare func $js.called.referred) + + ;; CHECK: (export "export" (func $export)) + + ;; CHECK: (@binaryen.js.called) + ;; CHECK-NEXT: (func $js.called.referred (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 10) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (@binaryen.js.called) + (func $js.called.referred + ;; This is jsCalled, and referred below, so it is kept. + (drop (i32.const 10)) + ) + + ;; CHECK: (func $export (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.func $js.called.referred) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $export (export "export") + (drop (ref.func $js.called.referred)) + ) + + (@binaryen.js.called) + (func $js.called.unreferred + ;; This is jsCalled, and not referred anywhere. The annotation does not + ;; stop the function from being removed entirely. + (drop (i32.const 20)) + ) +) diff --git a/test/lit/passes/strip-toolchain-annotations-func.wast b/test/lit/passes/strip-toolchain-annotations-func.wast index 564676179ed..80cf695273c 100644 --- a/test/lit/passes/strip-toolchain-annotations-func.wast +++ b/test/lit/passes/strip-toolchain-annotations-func.wast @@ -31,8 +31,9 @@ ;; CHECK-NEXT: ) (@binaryen.removable.if.unused) (@metadata.code.inline "\00") + (@binaryen.js.called) (func $test-func-d - ;; Reverse order of above. + ;; Reverse order of above, and also includes js.called which is removed. ) )