@@ -1061,6 +1061,40 @@ object desugar {
10611061 name
10621062 }
10631063
1064+ /** Strip parens and empty blocks around the body of `tree`. */
1065+ def normalizePolyFunction (tree : PolyFunction )(using Context ): PolyFunction =
1066+ def stripped (body : Tree ): Tree = body match
1067+ case Parens (body1) =>
1068+ stripped(body1)
1069+ case Block (Nil , body1) =>
1070+ stripped(body1)
1071+ case _ => body
1072+ cpy.PolyFunction (tree)(tree.targs, stripped(tree.body)).asInstanceOf [PolyFunction ]
1073+
1074+ /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
1075+ * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1076+ */
1077+ def makePolyFunctionType (tree : PolyFunction )(using Context ): RefinedTypeTree =
1078+ val PolyFunction (tparams : List [untpd.TypeDef ] @ unchecked, fun @ untpd.Function (vparamTypes, res)) = tree : @ unchecked
1079+ val funFlags = fun match
1080+ case fun : FunctionWithMods =>
1081+ fun.mods.flags
1082+ case _ => EmptyFlags
1083+
1084+ // TODO: make use of this in the desugaring when pureFuns is enabled.
1085+ // val isImpure = funFlags.is(Impure)
1086+
1087+ // Function flags to be propagated to each parameter in the desugared method type.
1088+ val paramFlags = funFlags.toTermFlags & Given
1089+ val vparams = vparamTypes.zipWithIndex.map:
1090+ case (p : ValDef , _) => p.withAddedFlags(paramFlags)
1091+ case (p, n) => makeSyntheticParameter(n + 1 , p).withAddedFlags(paramFlags)
1092+
1093+ RefinedTypeTree (ref(defn.PolyFunctionType ), List (
1094+ DefDef (nme.apply, tparams :: vparams :: Nil , res, EmptyTree ).withFlags(Synthetic )
1095+ )).withSpan(tree.span)
1096+ end makePolyFunctionType
1097+
10641098 /** Invent a name for an anonympus given of type or template `impl`. */
10651099 def inventGivenOrExtensionName (impl : Tree )(using Context ): SimpleName =
10661100 val str = impl match
@@ -1454,17 +1488,20 @@ object desugar {
14541488 }
14551489
14561490 /** Make closure corresponding to function.
1457- * params => body
1491+ * [tparams] => params => body
14581492 * ==>
1459- * def $anonfun(params) = body
1493+ * def $anonfun[tparams] (params) = body
14601494 * Closure($anonfun)
14611495 */
1462- def makeClosure (params : List [ValDef ], body : Tree , tpt : Tree | Null = null , isContextual : Boolean , span : Span )(using Context ): Block =
1496+ def makeClosure (tparams : List [TypeDef ], vparams : List [ValDef ], body : Tree , tpt : Tree | Null = null , span : Span )(using Context ): Block =
1497+ val paramss : List [ParamClause ] =
1498+ if tparams.isEmpty then vparams :: Nil
1499+ else tparams :: vparams :: Nil
14631500 Block (
1464- DefDef (nme.ANON_FUN , params :: Nil , if (tpt == null ) TypeTree () else tpt, body)
1501+ DefDef (nme.ANON_FUN , paramss , if (tpt == null ) TypeTree () else tpt, body)
14651502 .withSpan(span)
14661503 .withMods(synthetic | Artifact ),
1467- Closure (Nil , Ident (nme.ANON_FUN ), if (isContextual) ContextualEmptyTree else EmptyTree ))
1504+ Closure (Nil , Ident (nme.ANON_FUN ), EmptyTree ))
14681505
14691506 /** If `nparams` == 1, expand partial function
14701507 *
@@ -1753,62 +1790,6 @@ object desugar {
17531790 }
17541791 }
17551792
1756- def makePolyFunction (targs : List [Tree ], body : Tree , pt : Type ): Tree = body match {
1757- case Parens (body1) =>
1758- makePolyFunction(targs, body1, pt)
1759- case Block (Nil , body1) =>
1760- makePolyFunction(targs, body1, pt)
1761- case Function (vargs, res) =>
1762- assert(targs.nonEmpty)
1763- // TODO: Figure out if we need a `PolyFunctionWithMods` instead.
1764- val mods = body match {
1765- case body : FunctionWithMods => body.mods
1766- case _ => untpd.EmptyModifiers
1767- }
1768- val polyFunctionTpt = ref(defn.PolyFunctionType )
1769- val applyTParams = targs.asInstanceOf [List [TypeDef ]]
1770- if (ctx.mode.is(Mode .Type )) {
1771- // Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1772- // Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1773-
1774- val applyVParams = vargs.zipWithIndex.map {
1775- case (p : ValDef , _) => p.withAddedFlags(mods.flags)
1776- case (p, n) => makeSyntheticParameter(n + 1 , p).withAddedFlags(mods.flags.toTermFlags)
1777- }
1778- RefinedTypeTree (polyFunctionTpt, List (
1779- DefDef (nme.apply, applyTParams :: applyVParams :: Nil , res, EmptyTree ).withFlags(Synthetic )
1780- ))
1781- }
1782- else {
1783- // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1784- // with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1785- // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1786- // where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1787-
1788- def typeTree (tp : Type ) = tp match
1789- case RefinedType (parent, nme.apply, PolyType (_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1790- var bail = false
1791- def mapper (tp : Type , topLevel : Boolean = false ): Tree = tp match
1792- case tp : TypeRef => ref(tp)
1793- case tp : TypeParamRef => Ident (applyTParams(tp.paramNum).name)
1794- case AppliedType (tycon, args) => AppliedTypeTree (mapper(tycon), args.map(mapper(_)))
1795- case _ => if topLevel then TypeTree () else { bail = true ; genericEmptyTree }
1796- val mapped = mapper(mt.resultType, topLevel = true )
1797- if bail then TypeTree () else mapped
1798- case _ => TypeTree ()
1799-
1800- val applyVParams = vargs.asInstanceOf [List [ValDef ]]
1801- .map(varg => varg.withAddedFlags(mods.flags | Param ))
1802- New (Template (emptyConstructor, List (polyFunctionTpt), Nil , EmptyValDef ,
1803- List (DefDef (nme.apply, applyTParams :: applyVParams :: Nil , typeTree(pt), res))
1804- ))
1805- }
1806- case _ =>
1807- // may happen for erroneous input. An error will already have been reported.
1808- assert(ctx.reporter.errorsReported)
1809- EmptyTree
1810- }
1811-
18121793 // begin desugar
18131794
18141795 // Special case for `Parens` desugaring: unlike all the desugarings below,
@@ -1821,8 +1802,6 @@ object desugar {
18211802 }
18221803
18231804 val desugared = tree match {
1824- case PolyFunction (targs, body) =>
1825- makePolyFunction(targs, body, pt) orElse tree
18261805 case SymbolLit (str) =>
18271806 Apply (
18281807 ref(defn.ScalaSymbolClass .companionModule.termRef),
0 commit comments