User: joe Date: 12 Jun 25 00:01 Revision: 341f481a921b0826535194f4dacf698de07b53ae Summary: Lambda desugarer TeamCity URL: http://ci.mcdev.io:80/viewModification.html?tab=vcsModificationFiles&modId=10072&personal=false Index: src/main/kotlin/platform/mixin/expression/MEExpressionMatchUtil.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/MEExpressionMatchUtil.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/MEExpressionMatchUtil.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,8 +20,10 @@ package com.demonwav.mcdev.platform.mixin.expression +import com.demonwav.mcdev.platform.mixin.expression.psi.MEMatchableElement import com.demonwav.mcdev.platform.mixin.handlers.InjectorAnnotationHandler import com.demonwav.mcdev.platform.mixin.handlers.MixinAnnotationHandler +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.CollectVisitor import com.demonwav.mcdev.platform.mixin.util.LocalInfo import com.demonwav.mcdev.platform.mixin.util.MixinConstants @@ -30,6 +32,8 @@ import com.demonwav.mcdev.util.constantStringValue import com.demonwav.mcdev.util.descriptor import com.demonwav.mcdev.util.findAnnotations +import com.demonwav.mcdev.util.findContainingClass +import com.demonwav.mcdev.util.fullQualifiedName import com.demonwav.mcdev.util.resolveType import com.demonwav.mcdev.util.resolveTypeArray import com.github.benmanes.caffeine.cache.Caffeine @@ -38,7 +42,10 @@ import com.intellij.openapi.progress.ProcessCanceledException import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.project.Project +import com.intellij.psi.PsiExpression +import com.intellij.psi.PsiMethodCallExpression import com.intellij.psi.PsiModifierList +import com.intellij.psi.util.PsiUtil import com.llamalad7.mixinextras.expression.impl.ExpressionParserFacade import com.llamalad7.mixinextras.expression.impl.ExpressionService import com.llamalad7.mixinextras.expression.impl.ast.expressions.Expression @@ -46,6 +53,7 @@ import com.llamalad7.mixinextras.expression.impl.flow.FlowInterpreter import com.llamalad7.mixinextras.expression.impl.flow.FlowValue import com.llamalad7.mixinextras.expression.impl.flow.expansion.InsnExpander +import com.llamalad7.mixinextras.expression.impl.flow.postprocessing.LMFInfo import com.llamalad7.mixinextras.expression.impl.point.ExpressionContext import com.llamalad7.mixinextras.expression.impl.pool.IdentifierPool import com.llamalad7.mixinextras.expression.impl.pool.SimpleMemberDefinition @@ -318,3 +326,36 @@ val decorations: Map, ) } + +fun PsiExpression.matchesFlow(meMatchableElement: MEMatchableElement, context: MESourceMatchContext): Boolean { + val expr = PsiUtil.skipParenthesizedExprDown(this) ?: this + if (!DesugarUtil.isValuePopped(expr) && meMatchableElement.matchesJava(expr, context)) { + return true + } + for (other in DesugarUtil.getOtherFlowChildren(expr)) { + val actualOther = PsiUtil.skipParenthesizedExprDown(other) ?: other + if (!DesugarUtil.isValuePopped(actualOther) && meMatchableElement.matchesJava(actualOther, context)) { + return true + } + } + + return false +} + +fun DesugarUtil.IndyData.lmfType(methodCall: PsiMethodCallExpression): LMFInfo.Type? { + val impl = bsmArgs.getOrNull(1) as? Handle ?: return null + val bound = methodCall.argumentList.expressionCount != 0 + val containingClass = methodCall.findContainingClass() ?: return null + when (impl.tag) { + Opcodes.H_NEWINVOKESPECIAL -> return if (bound) null else LMFInfo.Type.INSTANTIATION + Opcodes.H_INVOKESPECIAL -> { + if (impl.owner != containingClass.fullQualifiedName?.replace('.', '/')) { + return null + } + return if (bound) LMFInfo.Type.BOUND_METHOD else LMFInfo.Type.FREE_METHOD + } + Opcodes.H_INVOKEVIRTUAL, Opcodes.H_INVOKEINTERFACE -> return if (bound) LMFInfo.Type.BOUND_METHOD else LMFInfo.Type.FREE_METHOD + Opcodes.H_INVOKESTATIC -> return LMFInfo.Type.FREE_METHOD + } + return null +} Index: src/main/kotlin/platform/mixin/expression/psi/mixins/METypeMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/METypeMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/METypeMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,10 +23,12 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.intellij.psi.PsiElement import com.intellij.psi.PsiType +import org.objectweb.asm.Type interface METypeMixin : PsiElement { val isArray: Boolean val dimensions: Int + fun matches(type: Type, context: MESourceMatchContext): Boolean fun matchesJava(java: PsiType, context: MESourceMatchContext): Boolean } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEArgumentsImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEArgumentsImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEArgumentsImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -22,11 +22,11 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.expression.psi.mixins.MEArgumentsMixin import com.intellij.extapi.psi.ASTWrapperPsiElement import com.intellij.lang.ASTNode import com.intellij.psi.PsiExpression -import com.intellij.psi.util.PsiUtil abstract class MEArgumentsImplMixin(node: ASTNode) : ASTWrapperPsiElement(node), MEArgumentsMixin { override fun matchesJava(java: Array, context: MESourceMatchContext): Boolean { @@ -35,8 +35,7 @@ return false } return exprs.asSequence().zip(java.asSequence()).all { (expr, javaExpr) -> - val actualJavaExpr = PsiUtil.skipParenthesizedExprDown(javaExpr) ?: return@all false - expr.matchesJava(actualJavaExpr, context) + javaExpr.matchesFlow(expr, context) } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEArrayAccessExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEArrayAccessExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEArrayAccessExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpressionTypes import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.expression.psi.MEPsiUtil import com.demonwav.mcdev.platform.mixin.expression.psi.mixins.MEArrayAccessExpressionMixin import com.intellij.lang.ASTNode @@ -46,9 +47,9 @@ return false } - val javaArray = PsiUtil.skipParenthesizedExprDown(java.arrayExpression) ?: return false - val javaIndex = PsiUtil.skipParenthesizedExprDown(java.indexExpression) ?: return false - return arrayExpr.matchesJava(javaArray, context) && indexExpr?.matchesJava(javaIndex, context) == true + val indexExpr = this.indexExpr ?: return false + return java.arrayExpression.matchesFlow(arrayExpr, context) && + java.indexExpression?.matchesFlow(indexExpr, context) == true } override fun getInputExprs() = listOfNotNull(arrayExpr, indexExpr) Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEAssignStatementImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEAssignStatementImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEAssignStatementImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,11 +23,11 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEStatementImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.intellij.lang.ASTNode import com.intellij.psi.JavaTokenType import com.intellij.psi.PsiAssignmentExpression import com.intellij.psi.PsiElement -import com.intellij.psi.util.PsiUtil import com.siyeh.ig.PsiReplacementUtil abstract class MEAssignStatementImplMixin(node: ASTNode) : MEStatementImpl(node) { @@ -42,10 +42,10 @@ java } - val leftJava = PsiUtil.skipParenthesizedExprDown(expandedJava.lExpression) ?: return false - val rightJava = PsiUtil.skipParenthesizedExprDown(expandedJava.rExpression) ?: return false + val rightExpr = this.rightExpr ?: return false context.fakeElementScope(isOperatorAssignment, java) { - return targetExpr.matchesJava(leftJava, context) && rightExpr?.matchesJava(rightJava, context) == true + return expandedJava.lExpression.matchesFlow(targetExpr, context) && + expandedJava.rExpression?.matchesFlow(rightExpr, context) == true } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEBinaryExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEBinaryExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEBinaryExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpressionTypes import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.expression.psi.METypeUtil import com.demonwav.mcdev.platform.mixin.expression.psi.mixins.MEBinaryExpressionMixin import com.intellij.lang.ASTNode @@ -33,7 +34,6 @@ import com.intellij.psi.PsiInstanceOfExpression import com.intellij.psi.PsiTypeTestPattern import com.intellij.psi.tree.TokenSet -import com.intellij.psi.util.PsiUtil abstract class MEBinaryExpressionImplMixin(node: ASTNode) : MEExpressionImpl(node), MEBinaryExpressionMixin { override val operator get() = node.findChildByType(operatorTokens)!!.elementType @@ -83,9 +83,9 @@ return false } - val javaLeft = PsiUtil.skipParenthesizedExprDown(java.lOperand) ?: return false - val javaRight = PsiUtil.skipParenthesizedExprDown(java.rOperand) ?: return false - return leftExpr.matchesJava(javaLeft, context) && rightExpr?.matchesJava(javaRight, context) == true + val rightExpr = this.rightExpr ?: return false + return java.lOperand.matchesFlow(leftExpr, context) && + java.rOperand?.matchesFlow(rightExpr, context) == true } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEBoundReferenceExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEBoundReferenceExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEBoundReferenceExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,25 +23,31 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEName -import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.QualifiedMember +import com.demonwav.mcdev.platform.mixin.expression.lmfType +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.intellij.lang.ASTNode import com.intellij.psi.PsiElement -import com.intellij.psi.PsiMethod -import com.intellij.psi.PsiMethodReferenceExpression -import com.intellij.psi.util.PsiUtil +import com.intellij.psi.PsiMethodCallExpression +import com.llamalad7.mixinextras.expression.impl.flow.postprocessing.LMFInfo +import org.objectweb.asm.Handle abstract class MEBoundReferenceExpressionImplMixin(node: ASTNode) : MEExpressionImplMixin(node), MEExpression { override fun matchesJava(java: PsiElement, context: MESourceMatchContext): Boolean { - if (java !is PsiMethodReferenceExpression) { + if (java !is PsiMethodCallExpression) { return false } - if (java.isConstructor) { + val indyData = DesugarUtil.getIndyData(java) ?: return false + if (indyData.bsm.owner != "java/lang/invoke/LambdaMetafactory") { return false } + if (indyData.lmfType(java) != LMFInfo.Type.BOUND_METHOD) { + return false + } - val qualifier = PsiUtil.skipParenthesizedExprDown(java.qualifierExpression) ?: return false - if (!receiverExpr.matchesJava(qualifier, context)) { + val javaReceiver = java.argumentList.expressions.firstOrNull() ?: return false + if (!javaReceiver.matchesFlow(receiverExpr, context)) { return false } @@ -50,10 +56,9 @@ return true } - val method = java.resolve() as? PsiMethod ?: return false - val qualifierClass = QualifiedMember.resolveQualifier(java) ?: method.containingClass ?: return false + val implMethod = indyData.bsmArgs.getOrNull(1) as? Handle ?: return false return context.getMethods(memberName.text).any { reference -> - reference.matchMethod(method, qualifierClass) + reference.matchMethod(implMethod.owner, implMethod.name, implMethod.desc) } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MECastExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MECastExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MECastExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEParenthesizedExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.expression.psi.MEPsiUtil import com.demonwav.mcdev.platform.mixin.expression.psi.METypeUtil import com.demonwav.mcdev.platform.mixin.expression.psi.mixins.MECastExpressionMixin @@ -32,7 +33,6 @@ import com.intellij.psi.PsiInstanceOfExpression import com.intellij.psi.PsiTypeCastExpression import com.intellij.psi.PsiTypeTestPattern -import com.intellij.psi.util.PsiUtil abstract class MECastExpressionImplMixin(node: ASTNode) : MEExpressionImpl(node), MECastExpressionMixin { override val castType get() = castTypeExpr?.let(METypeUtil::convertExpressionToType) @@ -41,12 +41,12 @@ override val castedExpr get() = expressionList.lastOrNull() override fun matchesJava(java: PsiElement, context: MESourceMatchContext): Boolean { + val castedExpr = this.castedExpr ?: return false return when (java) { is PsiTypeCastExpression -> { val javaType = java.castType?.type ?: return false - val javaOperand = PsiUtil.skipParenthesizedExprDown(java.operand) ?: return false castType?.matchesJava(javaType, context) == true && - castedExpr?.matchesJava(javaOperand, context) == true + java.operand?.matchesFlow(castedExpr, context) == true } is PsiInstanceOfExpression -> { val pattern = java.pattern as? PsiTypeTestPattern Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEConstructorReferenceExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEConstructorReferenceExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEConstructorReferenceExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,22 +23,31 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.METype +import com.demonwav.mcdev.platform.mixin.expression.lmfType +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.intellij.lang.ASTNode import com.intellij.psi.PsiElement -import com.intellij.psi.PsiMethodReferenceExpression +import com.intellij.psi.PsiMethodCallExpression +import com.llamalad7.mixinextras.expression.impl.flow.postprocessing.LMFInfo +import org.objectweb.asm.Handle +import org.objectweb.asm.Type abstract class MEConstructorReferenceExpressionImplMixin(node: ASTNode) : MEExpressionImplMixin(node), MEExpression { override fun matchesJava(java: PsiElement, context: MESourceMatchContext): Boolean { - if (java !is PsiMethodReferenceExpression) { + if (java !is PsiMethodCallExpression) { return false } - if (!java.isConstructor) { + val indyData = DesugarUtil.getIndyData(java) ?: return false + if (indyData.bsm.owner != "java/lang/invoke/LambdaMetafactory") { return false } + if (indyData.lmfType(java) != LMFInfo.Type.INSTANTIATION) { + return false + } - val qualifierType = java.qualifierType?.type ?: return false - return className.matchesJava(qualifierType, context) + val implMethod = indyData.bsmArgs.getOrNull(1) as? Handle ?: return false + return className.matches(Type.getObjectType(implMethod.owner), context) } override fun getInputExprs() = emptyList() Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEFreeMethodReferenceExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEFreeMethodReferenceExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEFreeMethodReferenceExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,24 +23,28 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEName +import com.demonwav.mcdev.platform.mixin.expression.lmfType +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.intellij.lang.ASTNode -import com.intellij.psi.PsiClassType import com.intellij.psi.PsiElement -import com.intellij.psi.PsiMethod -import com.intellij.psi.PsiMethodReferenceExpression +import com.intellij.psi.PsiMethodCallExpression +import com.llamalad7.mixinextras.expression.impl.flow.postprocessing.LMFInfo +import org.objectweb.asm.Handle abstract class MEFreeMethodReferenceExpressionImplMixin(node: ASTNode) : MEExpressionImplMixin(node), MEExpression { override fun matchesJava(java: PsiElement, context: MESourceMatchContext): Boolean { - if (java !is PsiMethodReferenceExpression) { + if (java !is PsiMethodCallExpression) { return false } - if (java.isConstructor) { + val indyData = DesugarUtil.getIndyData(java) ?: return false + if (indyData.bsm.owner != "java/lang/invoke/LambdaMetafactory") { return false } + if (indyData.lmfType(java) != LMFInfo.Type.FREE_METHOD) { + return false + } - val qualifierClass = (java.qualifierType?.type as? PsiClassType)?.resolve() ?: return false - // check wildcard after checking for the qualifier class, otherwise the reference could have been qualified by // an expression. val memberName = this.memberName ?: return false @@ -48,9 +52,9 @@ return true } - val method = java.resolve() as? PsiMethod ?: return false + val implMethod = indyData.bsmArgs.getOrNull(1) as? Handle ?: return false return context.getMethods(memberName.text).any { reference -> - reference.matchMethod(method, qualifierClass) + reference.matchMethod(implMethod.owner, implMethod.name, implMethod.desc) } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEMemberAccessExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEMemberAccessExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEMemberAccessExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEName import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.QualifiedMember import com.intellij.lang.ASTNode import com.intellij.psi.JavaPsiFacade @@ -31,7 +32,6 @@ import com.intellij.psi.PsiField import com.intellij.psi.PsiModifier import com.intellij.psi.PsiReferenceExpression -import com.intellij.psi.util.PsiUtil import com.siyeh.ig.psiutils.ExpressionUtils abstract class MEMemberAccessExpressionImplMixin(node: ASTNode) : MEExpressionImpl(node) { @@ -52,10 +52,10 @@ return false } - val javaReceiver = PsiUtil.skipParenthesizedExprDown(java.qualifierExpression) + val javaReceiver = java.qualifierExpression ?: JavaPsiFacade.getElementFactory(context.project).createExpressionFromText("this", null) context.fakeElementScope(java.qualifierExpression == null, java) { - if (!receiverExpr.matchesJava(javaReceiver, context)) { + if (!javaReceiver.matchesFlow(receiverExpr, context)) { return false } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEMethodCallExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEMethodCallExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEMethodCallExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -25,13 +25,14 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEName import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.QualifiedMember import com.intellij.lang.ASTNode import com.intellij.psi.JavaPsiFacade import com.intellij.psi.PsiElement import com.intellij.psi.PsiMethodCallExpression import com.intellij.psi.PsiModifier -import com.intellij.psi.util.PsiUtil import com.siyeh.ig.psiutils.MethodCallUtils abstract class MEMethodCallExpressionImplMixin(node: ASTNode) : MEExpressionImpl(node) { @@ -40,6 +41,10 @@ return false } + if (DesugarUtil.isIndy(java)) { + return false + } + if (MethodCallUtils.hasSuperQualifier(java)) { return false } @@ -58,10 +63,10 @@ } } - val javaReceiver = PsiUtil.skipParenthesizedExprDown(java.methodExpression.qualifierExpression) + val javaReceiver = java.methodExpression.qualifierExpression ?: JavaPsiFacade.getElementFactory(context.project).createExpressionFromText("this", null) context.fakeElementScope(java.methodExpression.qualifierExpression == null, java.methodExpression) { - if (!receiverExpr.matchesJava(javaReceiver, context)) { + if (!javaReceiver.matchesFlow(receiverExpr, context)) { return false } } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MENewExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MENewExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MENewExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -26,13 +26,13 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpressionTypes import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEName import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.expression.meExpressionElementFactory import com.demonwav.mcdev.platform.mixin.expression.psi.mixins.MENewExpressionMixin import com.intellij.lang.ASTNode import com.intellij.psi.PsiArrayType import com.intellij.psi.PsiElement import com.intellij.psi.PsiNewExpression -import com.intellij.psi.util.PsiUtil import com.intellij.psi.util.siblings abstract class MENewExpressionImplMixin(node: ASTNode) : MEExpressionImpl(node), MENewExpressionMixin { @@ -103,8 +103,7 @@ return false } if (!javaArrayDims.asSequence().zip(arrayDims.asSequence()).all { (javaArrayDim, arrayDim) -> - val actualJavaDim = PsiUtil.skipParenthesizedExprDown(javaArrayDim) ?: return@all false - arrayDim.matchesJava(actualJavaDim, context) + javaArrayDim.matchesFlow(arrayDim, context) } ) { return false Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEReturnStatementImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEReturnStatementImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEReturnStatementImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,18 +23,18 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEStatementImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.intellij.lang.ASTNode import com.intellij.psi.PsiElement import com.intellij.psi.PsiReturnStatement -import com.intellij.psi.util.PsiUtil abstract class MEReturnStatementImplMixin(node: ASTNode) : MEStatementImpl(node) { override fun matchesJava(java: PsiElement, context: MESourceMatchContext): Boolean { if (java !is PsiReturnStatement) { return false } - val javaReturnValue = PsiUtil.skipParenthesizedExprDown(java.returnValue) ?: return false - return valueExpr?.matchesJava(javaReturnValue, context) == true + val valueExpr = this.valueExpr ?: return false + return java.returnValue?.matchesFlow(valueExpr, context) == true } override fun getInputExprs() = listOfNotNull(valueExpr) Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEStaticMethodCallExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEStaticMethodCallExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEStaticMethodCallExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEArguments import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEName import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.handlers.injectionPoint.QualifiedMember import com.intellij.lang.ASTNode import com.intellij.psi.PsiElement @@ -36,6 +37,10 @@ return false } + if (DesugarUtil.isIndy(java)) { + return false + } + val method = java.resolveMethod() ?: return false if (!method.hasModifierProperty(PsiModifier.STATIC)) { return false Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/METhrowStatementImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/METhrowStatementImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/METhrowStatementImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,10 +23,10 @@ import com.demonwav.mcdev.platform.mixin.expression.MESourceMatchContext import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEStatementImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.intellij.lang.ASTNode import com.intellij.psi.PsiElement import com.intellij.psi.PsiThrowStatement -import com.intellij.psi.util.PsiUtil abstract class METhrowStatementImplMixin(node: ASTNode) : MEStatementImpl(node) { override fun matchesJava(java: PsiElement, context: MESourceMatchContext): Boolean { @@ -34,8 +34,8 @@ return false } - val javaException = PsiUtil.skipParenthesizedExprDown(java.exception) ?: return false - return valueExpr?.matchesJava(javaException, context) == true + val valueExpr = this.valueExpr ?: return false + return java.exception?.matchesFlow(valueExpr, context) == true } override fun getInputExprs() = listOfNotNull(valueExpr) Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/METypeImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/METypeImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/METypeImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -27,27 +27,30 @@ import com.demonwav.mcdev.util.descriptor import com.intellij.extapi.psi.ASTWrapperPsiElement import com.intellij.lang.ASTNode -import com.intellij.psi.PsiArrayType import com.intellij.psi.PsiElement import com.intellij.psi.PsiType +import org.objectweb.asm.Type abstract class METypeImplMixin(node: ASTNode) : ASTWrapperPsiElement(node), METypeMixin { override val isArray get() = findChildByType(MEExpressionTypes.TOKEN_LEFT_BRACKET) != null override val dimensions get() = findChildrenByType(MEExpressionTypes.TOKEN_LEFT_BRACKET).size - override fun matchesJava(java: PsiType, context: MESourceMatchContext): Boolean { - if (MEName.isWildcard) { - return java.arrayDimensions >= dimensions + override fun matches(type: Type, context: MESourceMatchContext): Boolean { + val inputDimensions = if (type.sort == Type.ARRAY) type.dimensions else 0 + return if (MEName.isWildcard) { + inputDimensions >= dimensions } else { - var unwrappedElementType = java - repeat(dimensions) { - unwrappedElementType = (unwrappedElementType as? PsiArrayType)?.componentType ?: return false + context.getTypes(MEName.text).any { desc -> + val fullDesc = "[".repeat(dimensions) + desc + fullDesc == type.descriptor } - val descriptor = unwrappedElementType.descriptor - return context.getTypes(MEName.text).any { it == descriptor } } } + override fun matchesJava(java: PsiType, context: MESourceMatchContext): Boolean { + return matches(Type.getType(java.descriptor), context) + } + @Suppress("PropertyName") protected abstract val MEName: MEName } Index: src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEUnaryExpressionImplMixin.kt =================================================================== --- src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEUnaryExpressionImplMixin.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/expression/psi/mixins/impl/MEUnaryExpressionImplMixin.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -24,6 +24,7 @@ import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpression import com.demonwav.mcdev.platform.mixin.expression.gen.psi.MEExpressionTypes import com.demonwav.mcdev.platform.mixin.expression.gen.psi.impl.MEExpressionImpl +import com.demonwav.mcdev.platform.mixin.expression.matchesFlow import com.demonwav.mcdev.platform.mixin.expression.psi.mixins.MEUnaryExpressionMixin import com.intellij.lang.ASTNode import com.intellij.psi.JavaTokenType @@ -56,7 +57,8 @@ return false } - return expression?.matchesJava(javaOperand, context) == true + val expression = this.expression ?: return false + return javaOperand.matchesFlow(expression, context) } override fun getInputExprs() = listOfNotNull(expression) Index: src/main/kotlin/platform/mixin/handlers/desugar/AnonymousAndLocalClassDesugarer.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/desugar/AnonymousAndLocalClassDesugarer.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/desugar/AnonymousAndLocalClassDesugarer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -47,7 +47,6 @@ import com.intellij.psi.PsiJavaFile import com.intellij.psi.PsiMethod import com.intellij.psi.PsiMethodCallExpression -import com.intellij.psi.PsiMethodReferenceExpression import com.intellij.psi.PsiModifier import com.intellij.psi.PsiNewExpression import com.intellij.psi.PsiReferenceExpression @@ -96,7 +95,8 @@ collectUsedVariables(localClass) } - val typeParametersToCreate = calculateTypeParametersToCreate(targetClass, localClass, variableInfos) + val typeParametersToCreate = + DesugarUtil.getTypeParametersToCopy(localClass, targetClass, variableInfos.map { it.variable }) ChangeContextUtil.encodeContextInfo(localClass, false) renameReferences(project, context, localClass, variableInfos) updateLocalClassConstructors(project, localClass, variableInfos) @@ -177,35 +177,6 @@ } } - private fun calculateTypeParametersToCreate( - targetClass: PsiClass, - localClass: PsiClass, - variableInfos: Array - ): Collection { - val typeParameters = linkedSetOf() - - val visitor = object : JavaRecursiveElementWalkingVisitor() { - override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) { - super.visitReferenceElement(reference) - val resolved = reference.resolve() - if (resolved is PsiTypeParameter) { - val owner = resolved.owner - if (owner != null && !PsiTreeUtil.isAncestor(localClass, owner, false) && - !PsiTreeUtil.isAncestor(owner, targetClass, false)) { - typeParameters += resolved - } - } - } - } - - localClass.accept(visitor) - for (info in variableInfos) { - info.variable.typeElement?.accept(visitor) - } - - return typeParameters - } - private fun updateLocalClassConstructors( project: Project, localClass: PsiClass, @@ -223,13 +194,6 @@ val constructorCalls = mutableMapOf>() if (variableInfos.isNotEmpty()) { - for (reference in DesugarUtil.findReferencesInFile(localClass)) { - val methodRef = reference.element.parent as? PsiMethodReferenceExpression ?: continue - if (methodRef.isConstructor) { - DesugarUtil.desugarMethodReferenceToLambda(methodRef) - } - } - for (constructor in constructors) { for (reference in DesugarUtil.findReferencesInFile(constructor)) { var refElement = reference.element Index: src/main/kotlin/platform/mixin/handlers/desugar/DesugarContext.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/desugar/DesugarContext.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/desugar/DesugarContext.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,4 +20,4 @@ package com.demonwav.mcdev.platform.mixin.handlers.desugar -class DesugarContext(val classVersion: Int) +data class DesugarContext(val classVersion: Int) Index: src/main/kotlin/platform/mixin/handlers/desugar/DesugarUtil.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/desugar/DesugarUtil.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/desugar/DesugarUtil.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,44 +20,77 @@ package com.demonwav.mcdev.platform.mixin.handlers.desugar +import com.demonwav.mcdev.util.PsiChildPointer import com.demonwav.mcdev.util.cached import com.demonwav.mcdev.util.childrenOfType +import com.demonwav.mcdev.util.createChildPointer +import com.demonwav.mcdev.util.findContainingClass +import com.demonwav.mcdev.util.lockedCached +import com.demonwav.mcdev.util.normalize +import com.demonwav.mcdev.util.packageName import com.intellij.openapi.project.Project import com.intellij.openapi.util.Key import com.intellij.openapi.util.TextRange import com.intellij.openapi.util.UnfairTextRange +import com.intellij.openapi.util.text.StringUtil +import com.intellij.platform.workspace.storage.impl.cache.CachedValue +import com.intellij.psi.JavaPsiFacade import com.intellij.psi.JavaRecursiveElementWalkingVisitor import com.intellij.psi.PsiAnonymousClass import com.intellij.psi.PsiClass +import com.intellij.psi.PsiClassType import com.intellij.psi.PsiElement +import com.intellij.psi.PsiExpression +import com.intellij.psi.PsiJavaCodeReferenceElement import com.intellij.psi.PsiJavaFile -import com.intellij.psi.PsiLambdaExpression import com.intellij.psi.PsiMember import com.intellij.psi.PsiMethod -import com.intellij.psi.PsiMethodReferenceExpression +import com.intellij.psi.PsiMethodCallExpression +import com.intellij.psi.PsiModifier import com.intellij.psi.PsiNameIdentifierOwner +import com.intellij.psi.PsiParenthesizedExpression import com.intellij.psi.PsiReference import com.intellij.psi.PsiSubstitutor +import com.intellij.psi.PsiSuperExpression +import com.intellij.psi.PsiThisExpression +import com.intellij.psi.PsiType +import com.intellij.psi.PsiTypeCastExpression import com.intellij.psi.PsiTypeParameter +import com.intellij.psi.PsiTypes import com.intellij.psi.PsiVariable import com.intellij.psi.impl.light.LightMemberReference import com.intellij.psi.search.LocalSearchScope import com.intellij.psi.search.searches.ReferencesSearch import com.intellij.psi.util.PsiTreeUtil +import com.intellij.psi.util.PsiUtil import com.intellij.psi.util.parents -import com.intellij.refactoring.util.LambdaRefactoringUtil +import com.intellij.refactoring.util.ConflictsUtil +import com.intellij.refactoring.util.classMembers.ClassThisReferencesVisitor import com.intellij.util.JavaPsiConstructorUtil import com.intellij.util.Processor +import com.intellij.util.containers.MultiMap +import com.siyeh.ig.psiutils.PsiElementOrderComparator +import java.util.Objects import org.jetbrains.annotations.VisibleForTesting +import org.objectweb.asm.Handle +import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type +import org.objectweb.asm.util.Printer object DesugarUtil { private val ORIGINAL_ELEMENT_KEY = Key.create("mcdev.desugar.originalElement") private val UNNAMED_VARIABLE_KEY = Key.create("mcdev.desugar.unnamedVariable") + private val INDY_DATA_KEY = Key.create("mcdev.desugar.indyData") + private val OTHER_FLOW_CHILDREN_KEY = Key.create>>("mcdev.desugar.otherFlowChildren") + private val VALUE_POPPED_KEY = Key.create("mcdev.desugar.valuePopped") + private val FAKE_KEY = Key.create("mcdev.desugar.fake") private val DESUGARERS = arrayOf( + MethodReferenceToLambdaDesugarer, RemoveVarArgsDesugarer, AnonymousAndLocalClassDesugarer, FieldAssignmentDesugarer, + LambdaDesugarer, ) fun getOriginalElement(desugared: PsiElement): PsiElement? { @@ -90,6 +123,49 @@ variable.putCopyableUserData(UNNAMED_VARIABLE_KEY, value) } + fun getIndyData(methodCall: PsiMethodCallExpression): IndyData? { + return methodCall.getCopyableUserData(INDY_DATA_KEY) + } + + fun isIndy(methodCall: PsiMethodCallExpression): Boolean { + return getIndyData(methodCall) != null + } + + fun setIndyData(methodCall: PsiMethodCallExpression, indyData: IndyData) { + methodCall.putCopyableUserData(INDY_DATA_KEY, indyData) + } + + fun addOtherFlowChild(expression: PsiExpression, child: PsiExpression) { + val otherFlowChildren = expression.getCopyableUserData(OTHER_FLOW_CHILDREN_KEY) + if (otherFlowChildren == null) { + expression.putCopyableUserData(OTHER_FLOW_CHILDREN_KEY, mutableListOf(expression.createChildPointer(child))) + } else { + otherFlowChildren += expression.createChildPointer(child) + } + } + + fun getOtherFlowChildren(expression: PsiExpression): List { + return expression.getCopyableUserData(OTHER_FLOW_CHILDREN_KEY) + ?.mapNotNull { it.dereference(expression) } + ?: emptyList() + } + + fun isValuePopped(expression: PsiExpression): Boolean { + return expression.getCopyableUserData(VALUE_POPPED_KEY) == true + } + + fun setValuePopped(expression: PsiExpression, value: Boolean) { + expression.putCopyableUserData(VALUE_POPPED_KEY, value) + } + + fun isFake(element: PsiElement): Boolean { + return element.getCopyableUserData(FAKE_KEY) == true + } + + fun setFake(element: PsiElement, value: Boolean) { + element.putCopyableUserData(FAKE_KEY, value) + } + fun desugar(project: Project, clazz: PsiClass, context: DesugarContext): PsiClass? { val file = clazz.containingFile as? PsiJavaFile ?: return null return file.cached { @@ -200,14 +276,232 @@ return results } - internal fun desugarMethodReferenceToLambda(methodReference: PsiMethodReferenceExpression): PsiLambdaExpression? { - val originalMethodRef = getOriginalElement(methodReference) - val lambda = LambdaRefactoringUtil.convertMethodReferenceToLambda(methodReference, false, true) - ?: return null - setOriginalElement(lambda, originalMethodRef) - for (parameter in lambda.parameterList.parameters) { - setUnnamedVariable(parameter, true) + internal fun elementNeedsThis(containingClass: PsiClass, element: PsiElement): Boolean { + val elementNeedsThis = object : ClassThisReferencesVisitor(containingClass) { + var result = false + + override fun visitExplicitThis(referencedClass: PsiClass?, reference: PsiThisExpression?) { + result = true - } + } - return lambda + + override fun visitExplicitSuper(referencedClass: PsiClass?, reference: PsiSuperExpression?) { + result = true - } + } + + override fun visitClassMemberReferenceElement( + classMember: PsiMember?, + classMemberReference: PsiJavaCodeReferenceElement? + ) { + if (classMember == null || classMember == element) { + return -} + } + if (classMember.hasModifierProperty(PsiModifier.STATIC)) { + return + } + if (classMember is PsiTypeParameter) { + return + } + result = true + } + } + + element.accept(elementNeedsThis) + return elementNeedsThis.result + } + + internal fun getTypeParametersToCopy( + sourceContext: PsiElement, + destContext: PsiElement, + capturedVariables: Iterable, + extraTypes: Iterable = emptyList(), + ): List { + val typeParameters = mutableListOf() + val addedTypeParameters = mutableSetOf() + val visitQueue = ArrayDeque() + + fun acceptTypeParameter(typeParam: PsiTypeParameter) { + val owner = typeParam.owner + if (owner != null && !PsiTreeUtil.isAncestor(sourceContext, owner, false) && + !PsiTreeUtil.isAncestor(owner, destContext, false)) { + if (addedTypeParameters.add(typeParam)) { + typeParameters += typeParam + visitQueue.add(typeParam.extendsList) + } + } + } + + val visitor = object : JavaRecursiveElementWalkingVisitor() { + override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) { + super.visitReferenceElement(reference) + val resolved = reference.resolve() + if (resolved is PsiTypeParameter) { + acceptTypeParameter(resolved) + } + } + } + + visitQueue.add(sourceContext) + for (capturedVariable in capturedVariables) { + capturedVariable.typeElement?.let(visitQueue::add) + } + for (type in extraTypes) { + if (type is PsiClassType) { + val resolved = type.resolve() + if (resolved is PsiTypeParameter) { + acceptTypeParameter(resolved) + } + } + } + while (visitQueue.isNotEmpty()) { + visitQueue.removeFirst().accept(visitor) + } + + typeParameters.sortWith(PsiElementOrderComparator.getInstance()) + return typeParameters + } + + internal fun generateMethodName(project: Project, containingClass: PsiClass, args: List): String { + val factory = JavaPsiFacade.getElementFactory(project) + + val params = args.withIndex().map { (index, type) -> factory.createParameter("p$index", type.normalize()) } + + var i = 1 + while (true) { + val templateMethod = factory.createMethod("synthetic$i", PsiTypes.voidType()) + for (param in params) { + templateMethod.parameterList.add(param) + } + val conflicts = MultiMap() + ConflictsUtil.checkMethodConflicts(containingClass, null, templateMethod, conflicts) + if (conflicts.isEmpty) { + return "synthetic$i" + } + i++ + } + } + + internal fun createNullCheck(project: Project, expression: PsiExpression, classVersion: Int): PsiExpression { + val factory = JavaPsiFacade.getElementFactory(project) + if (classVersion >= Opcodes.V9) { + val nonNullAssertion = factory.createExpressionFromText( + "java.util.Objects.requireNonNull(expr)", + expression + ) as PsiMethodCallExpression + val innerExpr = nonNullAssertion.argumentList.expressions.first().replace(expression) + as PsiExpression + addOtherFlowChild(nonNullAssertion, innerExpr) + setValuePopped(nonNullAssertion, true) + return nonNullAssertion + } else { + var getClassExpr = factory.createExpressionFromText( + "(x).getClass()", + expression + ) as PsiMethodCallExpression + val exprType = expression.type + val nonNullAssertion = if (exprType != null) { + val castExpr = factory.createExpressionFromText( + "(Type) x", + null + ) as PsiTypeCastExpression + castExpr.castType!!.replace(factory.createTypeElement(exprType)) + getClassExpr = castExpr.operand!!.replace(getClassExpr) as PsiMethodCallExpression + setFake(castExpr, true) + setValuePopped(getClassExpr, true) + castExpr + } else { + getClassExpr + } + val innerExpr = + (getClassExpr.methodExpression.qualifierExpression as PsiParenthesizedExpression).expression!!.replace(expression) + as PsiExpression + addOtherFlowChild(nonNullAssertion, innerExpr) + setValuePopped(nonNullAssertion, true) + return nonNullAssertion + } + } + + // see com.sun.tools.javac.comp.Lower.access, accReq + fun needsBridgeMethod(expression: PsiJavaCodeReferenceElement, classVersion: Int): Boolean { + // method calls with qualified super need bridge methods + if (expression is PsiMethodCallExpression) { + val qualifier = PsiUtil.skipParenthesizedExprDown(expression.methodExpression.qualifierExpression) + if (qualifier is PsiSuperExpression && qualifier.qualifier != null) { + return true + } + } + + val resolved = expression.resolve() as? PsiMember ?: return false + val resolvedClass = resolved.containingClass ?: return false + val fromClass = expression.findContainingClass() ?: return false + + if (resolvedClass == fromClass) { + return false + } + + if (classVersion <= Opcodes.V1_8 && resolved.hasModifierProperty(PsiModifier.PRIVATE)) { + return true + } + + if (resolved.hasModifierProperty(PsiModifier.PROTECTED) && + fromClass.packageName != resolvedClass.packageName && + !fromClass.isInheritor(resolvedClass, true) + ) { + return true + } + + return false + } + + class IndyData(val methodName: String, val methodDesc: String, val bsm: Handle, vararg val bsmArgs: Any) { + override fun toString(): String { + return buildString { + append("IndyData(\n") + objectToString(methodName, " ").append(",\n") + objectToString(methodDesc, " ").append(",\n") + objectToString(bsm, " ").append(",\n") + for (bsmArg in bsmArgs) { + objectToString(bsmArg, " ").append(",\n") + } + append(")") + } + } + + private fun StringBuilder.objectToString(obj: Any, indent: String): StringBuilder { + append(indent) + when (obj) { + is Float -> append(obj).append("F") + is Long -> append(obj).append("L") + is Int, is Double -> append(obj) + is String -> append('"').append(StringUtil.escapeStringCharacters(obj).replace("$", "\\$")).append('"') + is Type -> append("Type.").append(if (obj.sort == Type.METHOD) { "getMethodType" } else { "getType" }) + .append("(\"").append(StringUtil.escapeStringCharacters(obj.descriptor).replace("$", "\\$")).append("\")") + is Handle -> { + append("Handle(\n") + val innerIndent = "$indent " + append(innerIndent).append("Opcodes.").append(Printer.HANDLE_TAG[obj.tag]).append(",\n") + objectToString(obj.owner, innerIndent).append(",\n") + objectToString(obj.name, innerIndent).append(",\n") + objectToString(obj.desc, innerIndent).append(",\n") + append(innerIndent).append(obj.isInterface).append("\n") + append(indent).append(")") + } + } + return this + } + + override fun hashCode(): Int { + return Objects.hash(methodName, methodDesc, bsm, bsmArgs.contentDeepHashCode()) + } + + override fun equals(other: Any?): Boolean { + if (other !is IndyData) { + return false + } + + return methodName == other.methodName && + methodDesc == other.methodDesc && + bsm == other.bsm && + bsmArgs.contentDeepEquals(other.bsmArgs) + } + } +} Index: src/main/kotlin/platform/mixin/handlers/desugar/LambdaDesugarer.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/desugar/LambdaDesugarer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/main/kotlin/platform/mixin/handlers/desugar/LambdaDesugarer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,777 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.platform.mixin.handlers.desugar + +import com.demonwav.mcdev.util.descriptor +import com.demonwav.mcdev.util.findContainingClass +import com.demonwav.mcdev.util.findContainingMethod +import com.demonwav.mcdev.util.fullQualifiedName +import com.demonwav.mcdev.util.mapToArray +import com.demonwav.mcdev.util.normalize +import com.demonwav.mcdev.util.removeWildcards +import com.demonwav.mcdev.util.signature +import com.demonwav.mcdev.util.toObjectType +import com.intellij.openapi.project.Project +import com.intellij.openapi.util.RecursionManager +import com.intellij.openapi.util.text.StringUtil +import com.intellij.psi.CommonClassNames +import com.intellij.psi.HierarchicalMethodSignature +import com.intellij.psi.JavaPsiFacade +import com.intellij.psi.JavaRecursiveElementWalkingVisitor +import com.intellij.psi.LambdaUtil +import com.intellij.psi.PsiAnonymousClass +import com.intellij.psi.PsiArrayInitializerExpression +import com.intellij.psi.PsiArrayType +import com.intellij.psi.PsiAssignmentExpression +import com.intellij.psi.PsiCall +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiClassInitializer +import com.intellij.psi.PsiClassType +import com.intellij.psi.PsiCodeBlock +import com.intellij.psi.PsiConditionalExpression +import com.intellij.psi.PsiDiamondType +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiExpression +import com.intellij.psi.PsiExpressionList +import com.intellij.psi.PsiExpressionStatement +import com.intellij.psi.PsiField +import com.intellij.psi.PsiFunctionalExpression +import com.intellij.psi.PsiIfStatement +import com.intellij.psi.PsiIntersectionType +import com.intellij.psi.PsiJavaCodeReferenceElement +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiLambdaExpression +import com.intellij.psi.PsiLiteralExpression +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiMethodCallExpression +import com.intellij.psi.PsiMethodReferenceExpression +import com.intellij.psi.PsiModifier +import com.intellij.psi.PsiParameter +import com.intellij.psi.PsiParenthesizedExpression +import com.intellij.psi.PsiReferenceExpression +import com.intellij.psi.PsiResolveHelper +import com.intellij.psi.PsiReturnStatement +import com.intellij.psi.PsiSubstitutor +import com.intellij.psi.PsiSuperExpression +import com.intellij.psi.PsiSwitchExpression +import com.intellij.psi.PsiSwitchLabelStatement +import com.intellij.psi.PsiSwitchStatement +import com.intellij.psi.PsiType +import com.intellij.psi.PsiTypeCastExpression +import com.intellij.psi.PsiTypes +import com.intellij.psi.PsiVariable +import com.intellij.psi.codeStyle.VariableKind +import com.intellij.psi.infos.MethodCandidateInfo +import com.intellij.psi.util.MethodSignature +import com.intellij.psi.util.MethodSignatureUtil +import com.intellij.psi.util.PsiSuperMethodUtil +import com.intellij.psi.util.PsiTreeUtil +import com.intellij.psi.util.PsiTypesUtil +import com.intellij.psi.util.PsiUtil +import com.intellij.psi.util.TypeConversionUtil +import com.intellij.psi.util.parentOfType +import com.siyeh.ig.psiutils.SwitchUtils +import com.siyeh.ig.psiutils.VariableNameGenerator +import java.lang.invoke.LambdaMetafactory +import org.objectweb.asm.Handle +import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type + +object LambdaDesugarer : Desugarer() { + private const val METAFACTORY_DESC = "(Ljava/lang/invoke/MethodHandles\$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;" + private const val ALT_METAFACTORY_DESC = "(Ljava/lang/invoke/MethodHandles\$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;" + + override fun desugar(project: Project, file: PsiJavaFile, context: DesugarContext) { + val expressionsToDesugar = mutableListOf() + val lambdaNames = mutableMapOf() + val lambdaCounts = mutableMapOf, Int>() + + file.accept(object : JavaRecursiveElementWalkingVisitor() { + override fun visitLambdaExpression(expression: PsiLambdaExpression) { + super.visitLambdaExpression(expression) + + val containingMethod = PsiTreeUtil.getParentOfType(expression, PsiMethod::class.java, PsiClassInitializer::class.java) + val containingMethodName = when (containingMethod) { + is PsiMethod -> if (containingMethod.isConstructor) "new" else containingMethod.name + is PsiClassInitializer -> if (containingMethod.hasModifierProperty(PsiModifier.STATIC)) "static" else "new" + else -> return + } + val containingClass = containingMethod.containingClass ?: return + + val functionalInterfaceType = getFunctionalInterfaceType(expression) + val serializableLambda = functionalInterfaceType != null && isSerializable(project, functionalInterfaceType) + + var lambdaMethodName = "lambda$$containingMethodName" + if (serializableLambda) { + val lambdaDisambiguation = getLambdaDisambiguation(expression, functionalInterfaceType) + lambdaMethodName += "$${Integer.toHexString(lambdaDisambiguation.hashCode())}" + } + + // for some reason serializable lambda counters start at 1 while others start at 0 + val lambdaCount = lambdaCounts[containingClass to lambdaMethodName] ?: if (serializableLambda) 1 else 0 + lambdaCounts[containingClass to lambdaMethodName] = lambdaCount + 1 + + lambdaNames[expression] = "$lambdaMethodName$$lambdaCount" + expressionsToDesugar += expression + } + + override fun elementFinished(element: PsiElement) { + if (element is PsiMethodReferenceExpression) { + expressionsToDesugar += element + } + } + }) + + val serializationDatas = mutableMapOf>>() + for (expr in expressionsToDesugar.asReversed()) { + val containingClass = expr.findContainingClass() ?: continue + val serializationData = when (expr) { + is PsiLambdaExpression -> desugarLambda(project, expr, lambdaNames[expr]!!) + is PsiMethodReferenceExpression -> desugarMethodReference(project, expr, context) + else -> throw IllegalStateException("Can only desugar lambdas and method references") + } + if (serializationData != null) { + serializationDatas.getOrPut(containingClass) { linkedMapOf() } + .getOrPut(serializationData.implMethodName) { mutableListOf() } += serializationData + } + } + + for ((containingClass, datas) in serializationDatas) { + if (datas.isNotEmpty()) { + containingClass.add(createDeserializeLambdaMethod(project, datas)) + } + } + } + + private fun getLambdaDisambiguation(lambda: PsiLambdaExpression, functionalInterfaceType: PsiType): String { + val functionalInterface = + (normalizeFunctionalInterfaceType(functionalInterfaceType) as? PsiClassType)?.resolve() ?: return "" + val originalLambda = DesugarUtil.getOriginalElement(lambda) ?: lambda + val method = originalLambda.findContainingMethod() + + val builder = StringBuilder() + if (method != null) { + // the throws declaration seems to be missing from the lambda disambiguation signature for some reason + builder.append(method.signature.substringBefore('^')).append(':') + } else { + // class initializers, field initializers, etc, seem to have this signature, even when the only constructor is otherwise + builder.append("()V:") + } + + builder.append(functionalInterface.fullQualifiedName).append(' ') + + val varDecl = PsiTreeUtil.getParentOfType(lambda, PsiVariable::class.java, true, PsiClass::class.java) + if (varDecl != null) { + builder.append(varDecl.name).append('=') + } + + val capturedVariables = getCapturedVariables(lambda, getMethodReferenceQualifier(lambda)) + for (capturedVar in capturedVariables) { + builder.append(capturedVar.type.signature).append(' ').append(capturedVar.name).append(',') + } + + return builder.toString() + } + + private fun desugarLambda(project: Project, lambda: PsiLambdaExpression, name: String): LambdaSerializationData? { + val lambdaBody = lambda.body ?: return null + val containingClass = lambda.findContainingClass() ?: return null + val functionalInterfaceType = getFunctionalInterfaceType(lambda) ?: return null + val functionalInterfaceMethods = getFunctionalInterfaceMethods(functionalInterfaceType).ifEmpty { return null } + val (functionalInterfaceMethod, functionalInterfaceSignature) = functionalInterfaceMethods.first() + val functionalReturnType = + functionalInterfaceSignature.substitutor.substitute(functionalInterfaceMethod.returnType) ?: return null + + val methodReferenceQualifier = getMethodReferenceQualifier(lambda) + + val capturesThis = DesugarUtil.elementNeedsThis(containingClass, lambda) + val capturedVariables = getCapturedVariables(lambda, methodReferenceQualifier) + val extraTypesToCheckForTypeParams = mutableListOf() + + val paramsToAdd = mutableListOf() + + val factory = JavaPsiFacade.getElementFactory(project) + + if (methodReferenceQualifier != null) { + val selfType = methodReferenceQualifier.type ?: return null + val selfName = VariableNameGenerator(lambdaBody, VariableKind.PARAMETER) + .byName("self") + .generate(true) + val newParam = factory.createParameter(selfName, selfType) + DesugarUtil.setUnnamedVariable(newParam, true) + paramsToAdd += newParam + } + + for (capturedVar in capturedVariables) { + val name = capturedVar.name ?: continue + paramsToAdd += factory.createParameter(name, capturedVar.type) + } + + if (lambda.hasFormalParameterTypes()) { + paramsToAdd += lambda.parameterList.parameters + for (param in lambda.parameterList.parameters) { + extraTypesToCheckForTypeParams += param.type + } + } else { + for ((param, paramType) in lambda.parameterList.parameters.zip(functionalInterfaceSignature.parameterTypes)) { + extraTypesToCheckForTypeParams += paramType + val newParam = factory.createParameter(param.name, paramType.removeWildcards()) + DesugarUtil.setOriginalElement(newParam, DesugarUtil.getOriginalElement(param)) + DesugarUtil.setUnnamedVariable(newParam, DesugarUtil.isUnnamedVariable(param)) + paramsToAdd += newParam + } + } + val typeParametersToCopy = DesugarUtil.getTypeParametersToCopy( + lambda, + containingClass, + capturedVariables, + extraTypesToCheckForTypeParams + ) + + val lambdaMethodText = buildString { + append("private ") + if (!capturesThis) { + append("static ") + } + if (typeParametersToCopy.isNotEmpty()) { + append(" ") + } + append("void ") + append(name) + append("() {}") + } + + val lambdaMethod = containingClass.add(factory.createMethodFromText(lambdaMethodText, null)) + as PsiMethod + DesugarUtil.setOriginalElement(lambdaMethod, DesugarUtil.getOriginalElement(lambda)) + lambdaMethod.returnTypeElement!!.replace(factory.createTypeElement(functionalReturnType.removeWildcards())) + + if (typeParametersToCopy.isNotEmpty()) { + val typeParamList = lambdaMethod.typeParameterList!! + for (typeParam in typeParametersToCopy) { + typeParamList.add(typeParam) + } + typeParamList.typeParameters.first().delete() // delete the T + } + + val paramList = lambdaMethod.parameterList + for (param in paramsToAdd) { + paramList.add(param) + } + + val capturedTypes = mutableListOf() + if (capturesThis) { + capturedTypes += factory.createType(containingClass) + } else { + methodReferenceQualifier?.type?.let { capturedTypes += it } + } + capturedVariables.mapTo(capturedTypes) { it.type } + val (indyCall, serializationData) = createIndyCall( + project, + containingClass, + lambda, + capturedTypes, + lambdaMethod + ) ?: return null + if (capturesThis) { + indyCall.argumentList.add(factory.createExpressionFromText("this", lambda)) + } else if (methodReferenceQualifier != null) { + indyCall.argumentList.add(methodReferenceQualifier) + } + for (variable in capturedVariables) { + val varName = variable.name ?: continue + indyCall.argumentList.add(factory.createExpressionFromText(varName, lambda)) + } + + methodReferenceQualifier?.replace(factory.createExpressionFromText(lambdaMethod.parameterList.parameters.first().name, null)) + when (lambdaBody) { + is PsiCodeBlock -> lambdaMethod.body!!.replace(lambdaBody) + is PsiExpression -> { + val statementToAdd = if (functionalReturnType == PsiTypes.voidType()) { + val stmt = factory.createStatementFromText("foo();", null) as PsiExpressionStatement + stmt.expression.replace(lambdaBody) + stmt + } else { + val stmt = factory.createStatementFromText("return x;", null) as PsiReturnStatement + stmt.returnValue!!.replace(lambdaBody) + stmt + } + lambdaMethod.body!!.add(statementToAdd) + } + else -> throw IllegalStateException("Lambda body must be PsiCodeBlock or PsiExpression") + } + + lambda.replace(indyCall) + + return serializationData + } + + private fun getCapturedVariables( + lambda: PsiLambdaExpression, + methodReferenceQualifier: PsiExpression? + ): Collection { + val capturedVariables = linkedSetOf() + + lambda.accept(object : JavaRecursiveElementWalkingVisitor() { + override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) { + doVisitReferenceElement(reference) + } + + override fun visitReferenceExpression(expression: PsiReferenceExpression) { + doVisitReferenceElement(expression) + } + + fun doVisitReferenceElement(element: PsiJavaCodeReferenceElement) { + val variable = element.resolve() as? PsiVariable ?: return + if (variable is PsiField) { + return + } + if (PsiTreeUtil.isAncestor(methodReferenceQualifier, element, false)) { + return + } + if (PsiTreeUtil.isAncestor(lambda, variable, true)) { + return + } + capturedVariables += variable + } + }) + + return capturedVariables + } + + private fun getMethodReferenceQualifier(lambda: PsiLambdaExpression): PsiExpression? { + if (!MethodReferenceToLambdaDesugarer.isQualifiedMethodReference(lambda)) { + return null + } + + return (LambdaUtil.extractSingleExpressionFromBody(lambda.body) as? PsiMethodCallExpression) + ?.methodExpression + ?.qualifierExpression + ?.takeUnless { it is PsiSuperExpression } + } + + private fun desugarMethodReference( + project: Project, + methodReference: PsiMethodReferenceExpression, + context: DesugarContext + ): LambdaSerializationData? { + val qualifier = methodReference.qualifier ?: return null + val boundInstance = (qualifier as? PsiExpression)?.takeIf { + it !is PsiReferenceExpression || it.resolve() !is PsiClass + } + val calledMethod = methodReference.resolve() ?: return null + val containingClass = methodReference.findContainingClass() ?: return null + + val (indyCall, serializationData) = createIndyCall( + project, + containingClass, + methodReference, + listOfNotNull(boundInstance?.type), + calledMethod + ) ?: return null + if (boundInstance != null) { + indyCall.argumentList.add(DesugarUtil.createNullCheck(project, boundInstance, context.classVersion)) + } + methodReference.replace(indyCall) + + return serializationData + } + + private fun createIndyCall( + project: Project, + containingClass: PsiClass, + functionalExpr: PsiFunctionalExpression, + capturedTypes: List, + implElement: PsiElement, + ): Pair? { + val containingClassName = containingClass.name ?: return null + val implClass = implElement.parentOfType(withSelf = true) ?: return null + val implMethod = implElement as? PsiMethod + val implMethodDesc = implMethod?.descriptor ?: "()V" + val functionalInterfaceType = getFunctionalInterfaceType(functionalExpr) ?: return null + val functionalInterfaceMethods = getFunctionalInterfaceMethods(functionalInterfaceType).ifEmpty { return null } + val normalizedFunctionalInterfaceType = normalizeFunctionalInterfaceType(functionalInterfaceType) + val (functionalInterfaceMethod, functionalInterfaceMethodSignature) = functionalInterfaceMethods.first() + val functionalInterfaceMethodDesc = functionalInterfaceMethod.descriptor ?: return null + val instantiatedMethodDesc = Type.getMethodType( + Type.getType(functionalInterfaceMethodSignature.substitutor.substitute(functionalInterfaceMethod.returnType).descriptor), + *functionalInterfaceMethodSignature.parameterTypes.mapToArray { Type.getType(it.descriptor) } + ) + + val serializableLambda = isSerializable(project, functionalInterfaceType) + val needsAltMetafactory = functionalInterfaceType is PsiIntersectionType || functionalInterfaceMethods.size > 1 || serializableLambda + + val implMethodHandle = Handle( + when { + implMethod?.isConstructor != false -> Opcodes.H_NEWINVOKESPECIAL + implMethod.hasModifierProperty(PsiModifier.STATIC) -> Opcodes.H_INVOKESTATIC + implClass.isInterface -> Opcodes.H_INVOKEINTERFACE + else -> Opcodes.H_INVOKEVIRTUAL + }, + implClass.fullQualifiedName?.replace('.', '/') ?: return null, + if (implMethod?.isConstructor != false) "" else implMethod.name, + implMethodDesc, + implClass.isInterface + ) + val bsmArgs = mutableListOf( + Type.getMethodType(functionalInterfaceMethodDesc), + implMethodHandle, + instantiatedMethodDesc + ) + + if (needsAltMetafactory) { + val intersectionTypes = if (functionalInterfaceType is PsiIntersectionType) { + functionalInterfaceType.conjuncts.filter { PsiUtil.resolveClassInType(it)?.qualifiedName != "java.io.Serializable" } + } else { + listOf(functionalInterfaceType) + } + val markerInterfaces = intersectionTypes.filter { it != normalizedFunctionalInterfaceType } + + var flags = LambdaMetafactory.FLAG_BRIDGES // bridges is always set for some reason + if (serializableLambda) { + flags = flags or LambdaMetafactory.FLAG_SERIALIZABLE + } + if (markerInterfaces.isNotEmpty()) { + flags = flags or LambdaMetafactory.FLAG_MARKERS + } + bsmArgs += flags + + if (markerInterfaces.isNotEmpty()) { + bsmArgs += markerInterfaces.size + markerInterfaces.mapTo(bsmArgs) { Type.getType(it.descriptor) } + } + + bsmArgs += functionalInterfaceMethods.size - 1 + for ((method, _) in functionalInterfaceMethods.drop(1)) { + val desc = method.descriptor ?: continue + bsmArgs += Type.getType(desc) + } + } + + val methodDesc = + "(${capturedTypes.joinToString("") { it.descriptor }})${normalizedFunctionalInterfaceType.descriptor}" + val indyData = DesugarUtil.IndyData( + functionalInterfaceMethod.name, + methodDesc, + Handle( + Opcodes.H_INVOKESTATIC, + "java/lang/invoke/LambdaMetafactory", + if (needsAltMetafactory) "altMetafactory" else "metafactory", + if (needsAltMetafactory) ALT_METAFACTORY_DESC else METAFACTORY_DESC, + false + ), + *bsmArgs.toTypedArray() + ) + + val factory = JavaPsiFacade.getElementFactory(project) + val methodName = DesugarUtil.generateMethodName(project, containingClass, capturedTypes) + + val createdMethod = factory.createMethodFromText("private static void $methodName() {}", null) + DesugarUtil.setFake(createdMethod, true) + createdMethod.returnTypeElement!!.replace(factory.createTypeElement(normalizedFunctionalInterfaceType.normalize())) + for ((i, paramType) in capturedTypes.withIndex()) { + createdMethod.parameterList.add(factory.createParameter("param${i + 1}", paramType.normalize())) + } + containingClass.add(createdMethod) + + val indyCall = factory.createExpressionFromText("$containingClassName.$methodName()", functionalExpr) + as PsiMethodCallExpression + DesugarUtil.setOriginalElement(indyCall, DesugarUtil.getOriginalElement(functionalExpr)) + DesugarUtil.setIndyData(indyCall, indyData) + + val serializationData = if (serializableLambda) { + LambdaSerializationData( + implMethodHandle.name, + implMethodHandle.tag, + Type.getType(normalizedFunctionalInterfaceType.descriptor).internalName, + functionalInterfaceMethod.name, + functionalInterfaceMethodDesc, + implMethodHandle.owner, + implMethodHandle.desc, + methodName, + capturedTypes, + indyData, + ) + } else { + null + } + + return indyCall to serializationData + } + + private fun getFunctionalInterfaceType(expression: PsiElement): PsiType? { + var parent = expression.parent + var element = expression + + while (parent is PsiParenthesizedExpression || parent is PsiConditionalExpression) { + if (parent is PsiConditionalExpression && parent.condition == element) { + return PsiTypes.booleanType() + } + element = parent + parent = parent.parent + } + + when (parent) { + is PsiArrayInitializerExpression -> { + val type = parent.type + if (type is PsiArrayType) { + return type.componentType + } + } + is PsiTypeCastExpression -> return parent.castType?.type + is PsiVariable -> return parent.type + is PsiAssignmentExpression -> { + if (expression is PsiExpression && !PsiUtil.isOnAssignmentLeftHand(expression)) { + return parent.lExpression.type + } + } + is PsiExpressionList -> { + val lambdaIdx = LambdaUtil.getLambdaIdx(parent, expression) + if (lambdaIdx >= 0) { + var granny = parent.parent + if (granny is PsiAnonymousClass) { + granny = granny.parent + } + if (granny is PsiCall) { + val resolveResult = PsiDiamondType.getDiamondsAwareResolveResult(granny) + val resolved = resolveResult.element + if (resolved is PsiMethod) { + val parameters = resolved.parameterList.parameters + val finalLambdaIdx = adjustLambdaIdx(lambdaIdx, resolved, parameters) + if (finalLambdaIdx < parameters.size) { + return PsiResolveHelper.ourGraphGuard.doPreventingRecursion(expression, !MethodCandidateInfo.isOverloadCheck()) { + resolveResult.substitutor.substitute(parameters[finalLambdaIdx].type) + } + } + } + } + } + } + is PsiReturnStatement -> return PsiTypesUtil.getMethodReturnType(parent) + is PsiLambdaExpression -> return LambdaUtil.getFunctionalInterfaceType(expression, true) + else -> { + val switchExpression = element.parentOfType() + if (switchExpression != null && expression in PsiUtil.getSwitchResultExpressions(switchExpression)) { + return getFunctionalInterfaceType(switchExpression) + } + } + } + + return null + } + + private fun adjustLambdaIdx(lambdaIdx: Int, resolved: PsiMethod, parameters: Array): Int { + return if (resolved.isVarArgs && lambdaIdx >= parameters.size) { + parameters.size - 1 + } else { + lambdaIdx + } + } + + private fun normalizeFunctionalInterfaceType(functionalInterfaceType: PsiType): PsiType { + return if (functionalInterfaceType is PsiIntersectionType) { + functionalInterfaceType.conjuncts.firstOrNull { + getFunctionalInterfaceMethods(it).isNotEmpty() + } ?: functionalInterfaceType + } else { + functionalInterfaceType + } + } + + // see com.sun.tools.javac.code.Types.functionalInterfaceBridges + private fun getFunctionalInterfaceMethods(functionalInterfaceType: PsiType): List> { + val types = if (functionalInterfaceType is PsiIntersectionType) { + functionalInterfaceType.conjuncts + } else { + arrayOf(functionalInterfaceType) + } + + val result = mutableListOf>() + for (type in types) { + val resolveResult = PsiUtil.resolveGenericsClassInType(type) + val clazz = resolveResult.element ?: continue + + outer@ + for (method in getAllAbstractMethods(clazz)) { + val substitutor = LambdaUtil.getSubstitutor(method, resolveResult) + val erasedSignature = method.getSignature(PsiSubstitutor.EMPTY) + val signature = method.getSignature(substitutor) + for ((existingMethod, existingSignature) in result) { + if (existingMethod.name != method.name) { + return emptyList() + } + if (!MethodSignatureUtil.areErasedParametersEqual(existingSignature, signature)) { + return emptyList() + } + + val erasedReturnType = method.returnType?.let(TypeConversionUtil::erasure) + val existingErasedReturnType = existingMethod.returnType?.let(TypeConversionUtil::erasure) + val existingErasedSignature = existingMethod.getSignature(PsiSubstitutor.EMPTY) + if (erasedReturnType == existingErasedReturnType && MethodSignatureUtil.areErasedParametersEqual(existingErasedSignature, erasedSignature)) { + continue@outer + } + } + + result += method to signature + } + } + + return result + } + + private fun getAllAbstractMethods(clazz: PsiClass): List { + if (!clazz.isInterface || clazz.isAnnotationType) { + return emptyList() + } + + return RecursionManager.doPreventingRecursion(clazz, true) { + val abstractMethods = mutableListOf() + val defaultMethods = mutableListOf() + for (method in clazz.methods) { + if (method.hasModifierProperty(PsiModifier.STATIC)) { + continue + } + if (method.hasModifierProperty(PsiModifier.ABSTRACT)) { + if (!overridesPublicObjectMethod(method.hierarchicalMethodSignature)) { + abstractMethods += method + } + } else { + defaultMethods += method + } + } + + for (superInterface in clazz.interfaces) { + getAllAbstractMethods(superInterface).filterTo(abstractMethods) { method -> + defaultMethods.none { defaultMethod -> + defaultMethod.name == method.name && PsiSuperMethodUtil.isSuperMethod(defaultMethod, method) + } + } + } + + abstractMethods + } ?: emptyList() + } + + private fun overridesPublicObjectMethod(methodSig: HierarchicalMethodSignature): Boolean { + val superSigs = methodSig.superSignatures + + return if (superSigs.isEmpty()) { + val method = methodSig.method + if (method.containingClass?.qualifiedName == CommonClassNames.JAVA_LANG_OBJECT) { + if (method.hasModifierProperty(PsiModifier.PUBLIC)) { + return true + } + } + + false + } else { + superSigs.any(::overridesPublicObjectMethod) + } + } + + private fun isSerializable(project: Project, type: PsiType): Boolean { + return JavaPsiFacade.getElementFactory(project) + .createTypeByFQClassName("java.io.Serializable") + .isAssignableFrom(type) + } + + private fun createDeserializeLambdaMethod( + project: Project, + serializationDatas: Map> + ): PsiMethod { + val methodText = buildString { + append("private static java.lang.Object \$deserializeLambda$(java.lang.invoke.SerializedLambda serializedLambda) {") + append("switch (serializedLambda.getImplMethodName()) {") + // we collected the serialization datas in reverse, so iterate in reverse again here so the cases appear in + // the correct order. + for (implMethodName in serializationDatas.keys.toList().asReversed()) { + append("case \"") + append(StringUtil.escapeStringCharacters(implMethodName)) + append("\":") + append("break;") + } + append("}") + append("throw new java.lang.IllegalArgumentException(\"Invalid lambda deserialization\");") + append("}") + } + + val factory = JavaPsiFacade.getElementFactory(project) + val method = factory.createMethodFromText(methodText, null) + DesugarUtil.setUnnamedVariable(method.parameterList.parameters.first(), true) + val switchStatement = method.body!!.statements.first() as PsiSwitchStatement + for (switchBranch in SwitchUtils.getSwitchBranches(switchStatement)) { + val branchValue = (switchBranch as PsiLiteralExpression).value as String + // PsiLiteralExpression -> PsiCaseLabelElementList -> PsiSwitchLabelStatement + val caseLabel = switchBranch.parent.parent as PsiSwitchLabelStatement + // don't iterate in reverse here like we do above; we are adding the if statements *after* the case label, + // which already has the effect of reversing them. + for (data in serializationDatas[branchValue]!!) { + val ifStatementText = buildString { + append("if (serializedLambda.getImplMethodKind() == ") + append(data.implMethodKind) + append(" && serializedLambda.getFunctionalInterfaceClass().equals(\"") + append(StringUtil.escapeStringCharacters(data.functionalInterfaceClass)) + append("\") && serializedLambda.getFunctionalInterfaceMethodName().equals(\"") + append(StringUtil.escapeStringCharacters(data.functionalInterfaceMethodName)) + append("\") && serializedLambda.getFunctionalInterfaceMethodSignature().equals(\"") + append(StringUtil.escapeStringCharacters(data.functionalInterfaceMethodSignature)) + append("\") && serializedLambda.getImplClass().equals(\"") + append(StringUtil.escapeStringCharacters(data.implClass)) + append("\") && serializedLambda.getImplMethodSignature().equals(\"") + append(StringUtil.escapeStringCharacters(data.implMethodSignature)) + append("\"))") + append(data.indyMethodName) + append("();") + } + val ifStatement = factory.createStatementFromText(ifStatementText, caseLabel) as PsiIfStatement + val indyCall = (ifStatement.thenBranch as PsiExpressionStatement).expression as PsiMethodCallExpression + DesugarUtil.setIndyData(indyCall, data.indyData) + for ((i, captureType) in data.captureTypes.withIndex()) { + val captureExpr = if ((captureType as? PsiClassType)?.resolve()?.qualifiedName == "java.lang.Object") { + factory.createExpressionFromText("serializedLambda.getCapturedArg($i)", caseLabel) + } else { + val castExpr = factory.createExpressionFromText("(x) serializedLambda.getCapturedArg($i)", caseLabel) + as PsiTypeCastExpression + castExpr.castType!!.replace(factory.createTypeElement(captureType.toObjectType(project))) + castExpr + } + indyCall.argumentList.add(captureExpr) + } + switchStatement.body!!.addAfter(ifStatement, caseLabel) + } + } + + return method + } + + private class LambdaSerializationData( + val implMethodName: String, + val implMethodKind: Int, + val functionalInterfaceClass: String, + val functionalInterfaceMethodName: String, + val functionalInterfaceMethodSignature: String, + val implClass: String, + val implMethodSignature: String, + val indyMethodName: String, + val captureTypes: List, + val indyData: DesugarUtil.IndyData, + ) +} Index: src/main/kotlin/platform/mixin/handlers/desugar/MethodReferenceToLambdaDesugarer.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/desugar/MethodReferenceToLambdaDesugarer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/main/kotlin/platform/mixin/handlers/desugar/MethodReferenceToLambdaDesugarer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,78 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.platform.mixin.handlers.desugar + +import com.demonwav.mcdev.util.childrenOfType +import com.demonwav.mcdev.util.hasSyntheticMethod +import com.intellij.openapi.project.Project +import com.intellij.openapi.util.Key +import com.intellij.psi.LambdaUtil +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiLambdaExpression +import com.intellij.psi.PsiMethodCallExpression +import com.intellij.psi.PsiMethodReferenceExpression +import com.intellij.psi.PsiReferenceExpression +import com.intellij.refactoring.util.LambdaRefactoringUtil + +object MethodReferenceToLambdaDesugarer : Desugarer() { + private val QUALIFIED_METHOD_REFERENCE_KEY = Key.create("mcdev.desugar.methodReferenceToLambda.qualifiedMethodReference") + + fun isQualifiedMethodReference(lambda: PsiLambdaExpression): Boolean { + return lambda.getCopyableUserData(QUALIFIED_METHOD_REFERENCE_KEY) == true + } + + override fun desugar(project: Project, file: PsiJavaFile, context: DesugarContext) { + for (methodRef in file.childrenOfType()) { + if (methodRef.hasSyntheticMethod(context.classVersion)) { + desugarMethodReferenceToLambda(methodRef) + } + } + } + + private fun desugarMethodReferenceToLambda(methodReference: PsiMethodReferenceExpression): PsiLambdaExpression? { + val qualifierExpression = methodReference.qualifierExpression?.takeIf { + it !is PsiReferenceExpression || it.resolve() !is PsiClass + }?.copy() + val originalMethodRef = DesugarUtil.getOriginalElement(methodReference) + val originalMethodName = methodReference.referenceNameElement?.let(DesugarUtil::getOriginalElement) + + val lambda = LambdaRefactoringUtil.convertMethodReferenceToLambda(methodReference, false, true) + ?: return null + + DesugarUtil.setOriginalElement(lambda, originalMethodRef) + lambda.putCopyableUserData(QUALIFIED_METHOD_REFERENCE_KEY, qualifierExpression != null) + for (parameter in lambda.parameterList.parameters) { + DesugarUtil.setUnnamedVariable(parameter, true) + } + + // convertMethodReferenceToLambda creates lambdas from text which loses the original elements. Add them back here + val methodCall = LambdaUtil.extractSingleExpressionFromBody(lambda.body) as? PsiMethodCallExpression + if (methodCall != null) { + methodCall.methodExpression.referenceNameElement?.let { DesugarUtil.setOriginalElement(it, originalMethodName) } + if (qualifierExpression != null) { + methodCall.methodExpression.qualifierExpression?.replace(qualifierExpression) + } + } + + return lambda + } +} Index: src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantStringMethodInjectionPoint.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantStringMethodInjectionPoint.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantStringMethodInjectionPoint.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -21,6 +21,7 @@ package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint import com.demonwav.mcdev.platform.mixin.handlers.MixinAnnotationHandler +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.reference.MixinSelector import com.demonwav.mcdev.platform.mixin.util.MethodTargetMember import com.demonwav.mcdev.platform.mixin.util.fakeResolve @@ -159,6 +160,10 @@ private val ldc: String?, ) : NavigationVisitor() { private fun isConstantStringMethodCall(expression: PsiMethodCallExpression): Boolean { + if (DesugarUtil.isIndy(expression)) { + return false + } + // Must return void if (expression.type != PsiTypes.voidType()) { return false @@ -182,6 +187,8 @@ } override fun visitMethodCallExpression(expression: PsiMethodCallExpression) { + super.visitMethodCallExpression(expression) + if (isConstantStringMethodCall(expression)) { expression.resolveMethod()?.let { method -> val matches = selector.matchMethod( @@ -193,8 +200,6 @@ } } } - - super.visitMethodCallExpression(expression) } } Index: src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/injectionPoint/InjectionPoint.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,6 +20,7 @@ package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.reference.MixinSelector import com.demonwav.mcdev.platform.mixin.reference.toMixinString import com.demonwav.mcdev.platform.mixin.util.InjectionPointSpecifier @@ -317,8 +318,10 @@ val result = mutableListOf() protected fun addResult(element: PsiElement) { + if (!DesugarUtil.isFake(element)) { - result += element - } + result += element + } + } open fun configureBytecodeTarget(classNode: ClassNode, methodNode: MethodNode) { } Index: src/main/kotlin/platform/mixin/handlers/injectionPoint/InvokeAssignInjectionPoint.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/injectionPoint/InvokeAssignInjectionPoint.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/injectionPoint/InvokeAssignInjectionPoint.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,6 +20,7 @@ package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.reference.MixinSelector import com.demonwav.mcdev.util.MemberReference import com.intellij.openapi.editor.Editor @@ -133,6 +134,12 @@ } override fun visitMethodCallExpression(expression: PsiMethodCallExpression) { + super.visitMethodCallExpression(expression) + + if (DesugarUtil.isIndy(expression)) { + return + } + val method = expression.resolveMethod() if (method != null) { val containingClass = method.containingClass @@ -150,8 +157,6 @@ visitMethodUsage(method, qualifier, expression) } - - super.visitMethodCallExpression(expression) } } Index: src/main/kotlin/platform/mixin/handlers/injectionPoint/InvokeInjectionPoint.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/injectionPoint/InvokeInjectionPoint.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/injectionPoint/InvokeInjectionPoint.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,6 +20,7 @@ package com.demonwav.mcdev.platform.mixin.handlers.injectionPoint +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.reference.MixinSelector import com.demonwav.mcdev.util.MemberReference import com.intellij.openapi.editor.Editor @@ -81,6 +82,12 @@ } override fun visitMethodCallExpression(expression: PsiMethodCallExpression) { + super.visitMethodCallExpression(expression) + + if (DesugarUtil.isIndy(expression)) { + return + } + val method = expression.resolveMethod() if (method != null) { val containingClass = method.containingClass @@ -98,17 +105,15 @@ visitMethodUsage(method, qualifier, expression) } - - super.visitMethodCallExpression(expression) } override fun visitNewExpression(expression: PsiNewExpression) { + super.visitNewExpression(expression) + val constructor = expression.resolveConstructor() if (constructor != null) { visitMethodUsage(constructor, constructor.containingClass!!, expression) } - - super.visitNewExpression(expression) } override fun visitForeachStatement(statement: PsiForeachStatement) { Index: src/main/kotlin/platform/mixin/handlers/mixinextras/ExpressionInjectionPoint.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/mixinextras/ExpressionInjectionPoint.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/handlers/mixinextras/ExpressionInjectionPoint.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -41,7 +41,6 @@ import com.demonwav.mcdev.util.findAnnotations import com.demonwav.mcdev.util.findContainingModifierList import com.demonwav.mcdev.util.findModule -import com.demonwav.mcdev.util.findMultiInjectionHost import com.demonwav.mcdev.util.ifEmpty import com.demonwav.mcdev.util.parseArray import com.demonwav.mcdev.util.resolveType @@ -262,6 +261,8 @@ private val matchContext: MESourceMatchContext ) : NavigationVisitor() { override fun visitElement(element: PsiElement) { + // TODO: preprocess Java tree to find and exclude synthetic elements (as mixin extras does in bytecode). + // Synthetic elements include, for example, the Objects.requireNonNull for method reference receivers. for (statement in statements) { if (statement.matchesJava(element, matchContext)) { if (matchContext.captures.isNotEmpty()) { Index: src/main/kotlin/platform/mixin/util/AsmUtil.kt =================================================================== --- src/main/kotlin/platform/mixin/util/AsmUtil.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/platform/mixin/util/AsmUtil.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -782,7 +782,7 @@ // walk inside the reference first, visits the qualifier first (it's first in the bytecode) super.visitMethodReferenceExpression(expression) - if (expression.hasSyntheticMethod) { + if (expression.hasSyntheticMethod(clazz.version)) { if (matcher.accept(expression)) { stopWalking() } Index: src/main/kotlin/util/PsiChildPointer.kt =================================================================== --- src/main/kotlin/util/PsiChildPointer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/main/kotlin/util/PsiChildPointer.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,83 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.util + +import com.intellij.psi.PsiElement +import com.intellij.psi.util.PsiTreeUtil +import java.lang.ref.WeakReference + +/** + * A pointer to a [PsiElement] within the subtree of another [PsiElement], that can survive copies of the parent + * element and find the same child in the copy. + */ +class PsiChildPointer internal constructor( + private var element: WeakReference, + private val path: List +) { + fun dereference(parent: PsiElement): T? { + val element = element.get() + if (element != null && PsiTreeUtil.isAncestor(parent, element, false)) { + return element + } + + var result = parent + for ((index, type) in path) { + result = generateSequence(result.firstChild) { it.nextSibling } + .filter { it.javaClass === type } + .drop(index) + .firstOrNull() + ?: return null + } + + @Suppress("UNCHECKED_CAST") // we checked the Class, it's part of the path + return (result as T).also { this.element = WeakReference(it) } + } + + internal data class PathElement(val index: Int, val type: Class) +} + +fun PsiElement.createChildPointer(child: T): PsiChildPointer { + if (child === this) { + return PsiChildPointer(WeakReference(this), emptyList()) + } + + val path = mutableListOf() + var element: PsiElement = child + var parent = child.parent + while (parent != null) { + val type = element.javaClass + val indexInParent = generateSequence(parent.firstChild) { it.nextSibling } + .filter { it.javaClass === type } + .indexOf(element) + + path += PsiChildPointer.PathElement(indexInParent, type) + + if (parent === this) { + path.reverse() + return PsiChildPointer(WeakReference(child), path) + } + + element = parent + parent = element.parent + } + + throw IllegalArgumentException("$child is not a child of $this") +} Index: src/main/kotlin/util/bytecode-utils.kt =================================================================== --- src/main/kotlin/util/bytecode-utils.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/util/bytecode-utils.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -25,12 +25,16 @@ import com.intellij.psi.PsiArrayType import com.intellij.psi.PsiClass import com.intellij.psi.PsiClassType -import com.intellij.psi.PsiField +import com.intellij.psi.PsiDisjunctionType +import com.intellij.psi.PsiIntersectionType import com.intellij.psi.PsiMethod import com.intellij.psi.PsiModifier import com.intellij.psi.PsiPrimitiveType import com.intellij.psi.PsiType +import com.intellij.psi.PsiTypeParameter import com.intellij.psi.PsiTypes +import com.intellij.psi.PsiVariable +import com.intellij.psi.PsiWildcardType import com.intellij.psi.search.GlobalSearchScope import com.intellij.psi.util.TypeConversionUtil import org.jetbrains.plugins.groovy.lang.resolve.processors.inference.type @@ -91,10 +95,91 @@ is PsiPrimitiveType -> builder.append(internalName) is PsiArrayType -> componentType.appendDescriptor(builder.append('[')) is PsiClassType -> appendInternalName(builder.append('L')).append(';') + is PsiWildcardType -> extendsBound.appendDescriptor(builder) + is PsiIntersectionType -> conjuncts.first().appendDescriptor(builder) + is PsiDisjunctionType -> leastUpperBound.appendDescriptor(builder) else -> throw IllegalArgumentException("Unsupported PsiType: $this") } } +val PsiType.signature + get() = appendSignature(StringBuilder()).toString() + +private fun PsiType.appendSignature(builder: StringBuilder): StringBuilder { + return when (this) { + is PsiPrimitiveType -> builder.append(internalName) + is PsiArrayType -> componentType.appendSignature(builder.append('[')) + is PsiClassType -> { + val resolveResult = resolveGenerics() + val resolved = resolveResult.element ?: return builder + val substitutions = resolveResult.substitutor.substitutionMap + if (resolved is PsiTypeParameter) { + builder.append('T').append(resolved.name).append(';') + } else { + builder.append('L') + val classes = generateSequence(resolved) { it.containingClass }.toList() + var firstClass = true + var hadGenerics = false + for (clazz in classes.asReversed()) { + if (firstClass) { + clazz.appendInternalName(builder) + firstClass = false + } else { + if (hadGenerics) { + builder.append('.') + } else { + builder.append('$') + } + builder.append(clazz.name) + } + + val typeArgs = clazz.typeParameterList?.typeParameters?.map(substitutions::get) + ?: emptyList() + if (typeArgs.isNotEmpty() && typeArgs.all { it != null }) { + hadGenerics = true + builder.append('<') + for (typeArg in typeArgs) { + typeArg!!.appendSignature(builder) + } + builder.append('>') + } + } + builder.append(';') + } + } + is PsiIntersectionType -> conjuncts.first().appendSignature(builder) + is PsiDisjunctionType -> leastUpperBound.appendSignature(builder) + is PsiWildcardType -> when { + isExtends -> extendsBound.appendSignature(builder.append('+')) + isSuper -> superBound.appendSignature(builder.append('-')) + else -> builder.append('*') + } + else -> throw IllegalArgumentException("Unsupported PsiType: $this") + } +} + +private fun PsiTypeParameter.appendSignature(builder: StringBuilder): StringBuilder { + builder.append(name) + + val extendsList = this.extendsList.referencedTypes + if (extendsList.isEmpty()) { + return builder.append(":Ljava/lang/Object;") + } + + val classBound = extendsList.first().takeIf { classBound -> + classBound.resolve()?.isInterface != true + } + + builder.append(':') + classBound?.appendSignature(builder) + + for (interfaceBound in extendsList.drop(if (classBound != null) 1 else 0)) { + interfaceBound.appendSignature(builder.append(':')) + } + + return builder +} + fun parseClassDescriptor(descriptor: String): String { val internalName = descriptor.substring(1, descriptor.length - 1) return internalName.replace('/', '.') @@ -132,6 +217,32 @@ } } +val PsiClass.signature: String + get() { + val builder = StringBuilder() + + val typeParams = typeParameterList?.typeParameters + if (!typeParams.isNullOrEmpty()) { + builder.append('<') + for (typeParam in typeParams) { + typeParam.appendSignature(builder) + } + builder.append('>') + } + + val superType = this.extendsListTypes.singleOrNull() + if (superType == null || isInterface) { + builder.append("Ljava/lang/Object;") + } else { + superType.appendSignature(builder) + } + val interfaces = if (isInterface) this.extendsListTypes else this.implementsListTypes + for (itf in interfaces) { + itf.appendSignature(builder) + } + return builder.toString() + } + fun PsiClass.findMethodsByInternalName(internalName: String, checkBases: Boolean = false): Array { return if (internalName == INTERNAL_CONSTRUCTOR_NAME) { constructors @@ -161,6 +272,33 @@ } } +val PsiMethod.signature: String + get() { + val builder = StringBuilder() + + val typeParams = typeParameterList?.typeParameters + if (!typeParams.isNullOrEmpty()) { + builder.append('<') + for (typeParam in typeParams) { + typeParam.appendSignature(builder) + } + builder.append('>') + } + + builder.append('(') + for (parameter in parameterList.parameters) { + parameter.type.appendSignature(builder) + } + builder.append(')') + returnType?.appendSignature(builder) + + for (exception in throwsList.referencedTypes) { + exception.appendSignature(builder.append('^')) + } + + return builder.toString() + } + @Throws(ClassNameResolutionFailedException::class) private fun PsiMethod.appendDescriptor(builder: StringBuilder): StringBuilder { builder.append('(') @@ -179,7 +317,7 @@ } // Field -val PsiField.descriptor: String? +val PsiVariable.descriptor: String? get() { return try { appendDescriptor(StringBuilder()).toString() @@ -188,5 +326,8 @@ } } +val PsiVariable.signature: String + get() = type.signature + @Throws(ClassNameResolutionFailedException::class) -private fun PsiField.appendDescriptor(builder: StringBuilder): StringBuilder = type.appendDescriptor(builder) +private fun PsiVariable.appendDescriptor(builder: StringBuilder): StringBuilder = type.appendDescriptor(builder) Index: src/main/kotlin/util/psi-utils.kt =================================================================== --- src/main/kotlin/util/psi-utils.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/main/kotlin/util/psi-utils.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -23,6 +23,7 @@ import com.demonwav.mcdev.facet.MinecraftFacet import com.demonwav.mcdev.platform.mcp.McpModule import com.demonwav.mcdev.platform.mcp.McpModuleType +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.intellij.codeInsight.lookup.LookupElementBuilder import com.intellij.debugger.impl.DebuggerUtilsEx import com.intellij.ide.highlighter.JavaClassFileType @@ -42,9 +43,14 @@ import com.intellij.psi.ElementManipulator import com.intellij.psi.ElementManipulators import com.intellij.psi.JavaPsiFacade +import com.intellij.psi.LambdaUtil import com.intellij.psi.PsiAnnotation +import com.intellij.psi.PsiArrayType +import com.intellij.psi.PsiCapturedWildcardType import com.intellij.psi.PsiClass +import com.intellij.psi.PsiClassType import com.intellij.psi.PsiDirectory +import com.intellij.psi.PsiDisjunctionType import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiElement import com.intellij.psi.PsiElementFactory @@ -52,6 +58,7 @@ import com.intellij.psi.PsiEllipsisType import com.intellij.psi.PsiExpression import com.intellij.psi.PsiFile +import com.intellij.psi.PsiIntersectionType import com.intellij.psi.PsiKeyword import com.intellij.psi.PsiLanguageInjectionHost import com.intellij.psi.PsiManager @@ -67,7 +74,11 @@ import com.intellij.psi.PsiPrimitiveType import com.intellij.psi.PsiReference import com.intellij.psi.PsiReferenceExpression +import com.intellij.psi.PsiSuperExpression import com.intellij.psi.PsiType +import com.intellij.psi.PsiTypeElement +import com.intellij.psi.PsiTypeParameter +import com.intellij.psi.PsiWildcardType import com.intellij.psi.ResolveResult import com.intellij.psi.filters.ElementFilter import com.intellij.psi.search.GlobalSearchScope @@ -76,6 +87,7 @@ import com.intellij.psi.util.CachedValuesManager import com.intellij.psi.util.PsiTreeUtil import com.intellij.psi.util.PsiTypesUtil +import com.intellij.psi.util.PsiUtil import com.intellij.psi.util.TypeConversionUtil import com.intellij.psi.util.parentOfType import com.intellij.refactoring.changeSignature.ChangeSignatureUtil @@ -278,9 +290,21 @@ if (normalized is PsiEllipsisType) { normalized = normalized.toArrayType() } + normalized = normalized.removeWildcards() + if (normalized is PsiIntersectionType) { + normalized = normalized.conjuncts.first() + } return normalized } +fun PsiType.removeWildcards(): PsiType { + return when (this) { + is PsiWildcardType -> extendsBound + is PsiCapturedWildcardType -> upperBound + else -> this + } +} + fun PsiType.toObjectType(project: Project): PsiType = when (val normalized = normalize()) { is PsiPrimitiveType -> @@ -371,22 +395,105 @@ } } -@Suppress("PrivatePropertyName") private val REAL_NAME_KEY = Key("mcdev.real_name") var PsiMember.realName: String? get() = getUserData(REAL_NAME_KEY) set(value) = putUserData(REAL_NAME_KEY, value) -val PsiMethodReferenceExpression.hasSyntheticMethod: Boolean +// see com.sun.tools.javac.comp.TransTypes.needsConversionToLambda +fun PsiMethodReferenceExpression.hasSyntheticMethod(classVersion: Int): Boolean { + val qualifier = this.qualifier ?: return true + + if (qualifier is PsiTypeElement && qualifier.type is PsiArrayType) { + return true + } + + if (qualifier is PsiSuperExpression) { + return true + } + + val referencedClass = when (qualifier) { + is PsiTypeElement -> (qualifier.type as? PsiClassType)?.resolve() + is PsiReferenceExpression -> qualifier.resolve() as? PsiClass + else -> null + } + + if (isConstructor) { + if (referencedClass?.containingClass != null && !referencedClass.hasModifierProperty(PsiModifier.STATIC)) { + return true + } + if (referencedClass != null && PsiUtil.isLocalOrAnonymousClass(referencedClass)) { + return true + } + } + + if (isVarArgsCall) { + return true + } + + if (DesugarUtil.needsBridgeMethod(this, classVersion)) { + return true + } + + // even if a bridge method isn't required, if the method is protected in a different package, a synthetic method is + // still required, because otherwise the synthetic class that LambdaMetafactory creates won't be able to access it + val resolved = resolve() ?: return true + when (resolved) { + is PsiClass -> return !isConstructor + !is PsiMethod -> return true + } + if (resolved.hasModifierProperty(PsiModifier.PROTECTED) && findContainingClass()?.packageName != referencedClass?.packageName) { + return true + } + + val functionalInterfaceType = this.functionalInterfaceType ?: return true + val interfaceMethod = LambdaUtil.getFunctionalInterfaceMethod(functionalInterfaceType) ?: return true + + return interfaceMethod.parameterList.parameters.any { param -> + var paramType = param.type + while (paramType is PsiClassType) { + val resolved = paramType.resolve() + if (resolved is PsiTypeParameter) { + val extendsList = resolved.extendsList.referencedTypes + when (extendsList.size) { + 0 -> break + 1 -> paramType = extendsList.single() + else -> return@any true + } + } + } + paramType is PsiIntersectionType || paramType is PsiDisjunctionType + } +} + +private val PsiMethodReferenceExpression.isVarArgsCall: Boolean get() { - // the only method reference that doesn't have a synthetic method is a direct reference to a method - if (referenceName == "new") return true - val qualifier = this.qualifier - if (qualifier !is PsiReferenceExpression) return true - return qualifier.resolve() !is PsiClass + val resolveResult = advancedResolve(false) + val resolvedMethod = resolveResult.element as? PsiMethod ?: return false + if (!resolvedMethod.isVarArgs) { + return false - } + } + val functionalInterfaceType = this.functionalInterfaceType ?: return true + val functionalResolveResult = PsiUtil.resolveGenericsClassInType(functionalInterfaceType) + val interfaceMethod = LambdaUtil.getFunctionalInterfaceMethod(functionalResolveResult) ?: return true + + val interfaceSignature = interfaceMethod.getSignature(LambdaUtil.getSubstitutor(interfaceMethod, functionalResolveResult)) + val interfaceParamTypes = interfaceSignature.parameterTypes + + val resolvedParams = resolvedMethod.parameterList.parameters + val isStatic = resolvedMethod.hasModifierProperty(PsiModifier.STATIC) + val effectiveNumParams = if (isStatic) resolvedParams.size else resolvedParams.size + 1 + if (effectiveNumParams != interfaceParamTypes.size) { + return true + } + + val varArgsType = resolvedParams.lastOrNull()?.type as? PsiEllipsisType ?: return true + val arrayType = resolveResult.substitutor.substitute(varArgsType.toArrayType()) + return !arrayType.isAssignableFrom(interfaceParamTypes.last()) + } + val PsiClass.psiType: PsiType get() = PsiTypesUtil.getClassType(this) Index: src/test/kotlin/platform/mixin/desugar/AbstractDesugarMultiTest.kt =================================================================== --- src/test/kotlin/platform/mixin/desugar/AbstractDesugarMultiTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/test/kotlin/platform/mixin/desugar/AbstractDesugarMultiTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,162 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.platform.mixin.desugar + +import com.demonwav.mcdev.framework.BaseMinecraftTest +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarContext +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil +import com.demonwav.mcdev.platform.mixin.handlers.desugar.Desugarer +import com.demonwav.mcdev.util.createChildPointer +import com.intellij.openapi.command.WriteCommandAction +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiMethodCallExpression +import com.intellij.psi.codeStyle.CodeStyleManager +import com.intellij.psi.codeStyle.JavaCodeStyleManager +import com.intellij.psi.util.PsiTreeUtil +import com.intellij.psi.util.childrenOfType +import com.intellij.psi.util.parentOfType +import com.intellij.testFramework.IndexingTestUtil +import kotlin.collections.filterIsInstance +import org.intellij.lang.annotations.Language +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertInstanceOf +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.objectweb.asm.Opcodes + +abstract class AbstractDesugarMultiTest : BaseMinecraftTest() { + companion object { + private const val DEFAULT_CLASS_VERSION = Opcodes.V21 + } + + abstract val desugarers: List + + protected fun doTestNoChange(@Language("JAVA") code: String, classVersion: Int = DEFAULT_CLASS_VERSION) { + doTest(code, code, classVersion) + } + + protected fun doTest( + @Language("JAVA") before: String, + @Language("JAVA") after: String, + classVersion: Int = DEFAULT_CLASS_VERSION + ): PsiElement? { + return WriteCommandAction.runWriteCommandAction(project) { + val codeStyleManager = CodeStyleManager.getInstance(project) + val javaCodeStyleManager = JavaCodeStyleManager.getInstance(project) + + val caretIndex = after.indexOf("") + val actualAfter = if (caretIndex >= 0) { + after.removeRange(caretIndex, caretIndex + "".length) + } else { + after + } + val expectedFile = fixture.addClass(actualAfter).containingFile + val elementAtCaret = if (caretIndex >= 0) { + val elementAtCaret = expectedFile.findElementAt(caretIndex) + assertNotNull(elementAtCaret, "Could not find element at caret") + expectedFile.createChildPointer(elementAtCaret!!) + } else { + null + } + assertEquals( + expectedFile, + codeStyleManager.reformat(expectedFile), + "Reformatting changed the file!", + ) + val expectedText = expectedFile.text + expectedFile.delete() + + val testFile = assertInstanceOf( + PsiJavaFile::class.java, + fixture.configureByText("Test.java", before) + ) + assertEquals( + testFile, + codeStyleManager.reformat(testFile), + "Reformatting changed the file!", + ) + + IndexingTestUtil.waitUntilIndexesAreReady(project) + + val desugaredFile = testFile.copy() as PsiJavaFile + DesugarUtil.setOriginalRecursive(desugaredFile, testFile) + for (desugarer in desugarers) { + desugarer.desugar(project, desugaredFile, DesugarContext(classVersion)) + } + assertEquals( + desugaredFile, + javaCodeStyleManager.shortenClassReferences(desugaredFile), + "Shortening class references changed the file!" + ) + assertEquals( + expectedText, + codeStyleManager.reformat(desugaredFile.copy()).text + ) + + PsiTreeUtil.processElements(desugaredFile) { desugaredElement -> + val originalElement = DesugarUtil.getOriginalElement(desugaredElement) + if (originalElement != null) { + assertTrue( + PsiTreeUtil.isAncestor(testFile, originalElement, false) + ) { + "The original element of $desugaredElement is not from the original file" + } + } + true + } + + val originalClasses = testFile.childrenOfType() + val desugaredClassesSet = mutableSetOf() + val originalToDesugaredMap = DesugarUtil.getOriginalToDesugaredMap(desugaredFile) + for (clazz in originalClasses) { + val desugaredClasses = originalToDesugaredMap[clazz]?.filterIsInstance() ?: emptyList() + assertEquals(1, desugaredClasses.size) { "Unexpected number of desugared classes for ${clazz.name}" } + desugaredClassesSet += desugaredClasses.first() + } + assertEquals(originalClasses.size, desugaredClassesSet.size, "Unexpected number of desugared classes") + + elementAtCaret?.let { + val result = elementAtCaret.dereference(desugaredFile) + assertNotNull(result, "Could not dereference element at caret") + result + } + } + } + + protected fun doIndyTest( + @Language("JAVA") before: String, + @Language("JAVA") after: String, + expectedIndyData: DesugarUtil.IndyData, + classVersion: Int = DEFAULT_CLASS_VERSION, + ) { + WriteCommandAction.runWriteCommandAction(project) { + val element = doTest(before, after, classVersion) + assertNotNull(element, "No caret found") + val indyCall = element!!.parentOfType() + assertNotNull(indyCall, "No method call found") + val indyData = DesugarUtil.getIndyData(indyCall!!) + assertNotNull(indyData, "Method call has no indy data") + assertEquals(expectedIndyData, indyData) + } + } +} Index: src/test/kotlin/platform/mixin/desugar/AbstractDesugarTest.kt =================================================================== --- src/test/kotlin/platform/mixin/desugar/AbstractDesugarTest.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/test/kotlin/platform/mixin/desugar/AbstractDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -20,86 +20,9 @@ package com.demonwav.mcdev.platform.mixin.desugar -import com.demonwav.mcdev.framework.BaseMinecraftTest -import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarContext -import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil import com.demonwav.mcdev.platform.mixin.handlers.desugar.Desugarer -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.psi.PsiClass -import com.intellij.psi.PsiJavaFile -import com.intellij.psi.codeStyle.CodeStyleManager -import com.intellij.psi.codeStyle.JavaCodeStyleManager -import com.intellij.psi.util.PsiTreeUtil -import com.intellij.psi.util.childrenOfType -import com.intellij.testFramework.IndexingTestUtil -import org.intellij.lang.annotations.Language -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Assertions.assertInstanceOf -import org.junit.jupiter.api.Assertions.assertTrue -import org.objectweb.asm.Opcodes -abstract class AbstractDesugarTest : BaseMinecraftTest() { +abstract class AbstractDesugarTest : AbstractDesugarMultiTest() { + final override val desugarers get() = listOf(desugarer) abstract val desugarer: Desugarer - - protected fun doTestNoChange(@Language("JAVA") code: String, classVersion: Int = Opcodes.V21) { - doTest(code, code, classVersion) - } +} - - protected fun doTest(@Language("JAVA") before: String, @Language("JAVA") after: String, classVersion: Int = Opcodes.V21) { - WriteCommandAction.runWriteCommandAction(project) { - val codeStyleManager = CodeStyleManager.getInstance(project) - val javaCodeStyleManager = JavaCodeStyleManager.getInstance(project) - - val expectedFile = fixture.addClass(after).containingFile - assertEquals( - expectedFile, - codeStyleManager.reformat(expectedFile), - "Reformatting changed the file!", - ) - val expectedText = expectedFile.text - expectedFile.delete() - - val testFile = assertInstanceOf( - PsiJavaFile::class.java, - fixture.configureByText("Test.java", before) - ) - assertEquals( - testFile, - codeStyleManager.reformat(testFile), - "Reformatting changed the file!", - ) - - IndexingTestUtil.waitUntilIndexesAreReady(project) - - val desugaredFile = testFile.copy() as PsiJavaFile - DesugarUtil.setOriginalRecursive(desugaredFile, testFile) - desugarer.desugar(project, desugaredFile, DesugarContext(classVersion)) - assertEquals( - expectedText, - codeStyleManager.reformat(javaCodeStyleManager.shortenClassReferences(desugaredFile.copy())).text - ) - - PsiTreeUtil.processElements(desugaredFile) { desugaredElement -> - val originalElement = DesugarUtil.getOriginalElement(desugaredElement) - if (originalElement != null) { - assertTrue( - PsiTreeUtil.isAncestor(testFile, originalElement, false) - ) { - "The original element of $desugaredElement is not from the original file" - } - } - true - } - - val originalClasses = testFile.childrenOfType() - val desugaredClassesSet = mutableSetOf() - val originalToDesugaredMap = DesugarUtil.getOriginalToDesugaredMap(desugaredFile) - for (clazz in originalClasses) { - val desugaredClasses = originalToDesugaredMap[clazz]?.filterIsInstance() ?: emptyList() - assertEquals(1, desugaredClasses.size) { "Unexpected number of desugared classes for ${clazz.name}" } - desugaredClassesSet += desugaredClasses.first() - } - assertEquals(originalClasses.size, desugaredClassesSet.size, "Unexpected number of desugared classes") - } - } -} Index: src/test/kotlin/platform/mixin/desugar/AnonymousAndLocalClassDesugarTest.kt =================================================================== --- src/test/kotlin/platform/mixin/desugar/AnonymousAndLocalClassDesugarTest.kt (revision 12c8403604e773544c6ffc47a4e599d47841b9c6) +++ src/test/kotlin/platform/mixin/desugar/AnonymousAndLocalClassDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -938,53 +938,6 @@ } @Test - fun testLocalWithCaptureCreatedWithConstructorReference() { - doTest( - """ - import java.util.function.Supplier; - - class Test { - void test() { - String hello = "Hello"; - - class Local { - void print() { - System.out.println(hello); - } - } - - Supplier supplier = Local::new; - } - } - """.trimIndent(), - """ - import java.util.function.Supplier; - - class Test { - void test() { - String hello = "Hello"; - - Supplier<$1Local> supplier = () -> new $1Local(hello); - } - - class $1Local { - final String val${'$'}hello; - - $1Local(String hello) { - this.val${'$'}hello = hello; - super(); - } - - void print() { - System.out.println(val${'$'}hello); - } - } - } - """.trimIndent() - ) - } - - @Test fun testCaptureUsedInConstructor() { doTest( """ Index: src/test/kotlin/platform/mixin/desugar/LambdaDesugarTest.kt =================================================================== --- src/test/kotlin/platform/mixin/desugar/LambdaDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/test/kotlin/platform/mixin/desugar/LambdaDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,872 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.platform.mixin.desugar + +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil +import com.demonwav.mcdev.platform.mixin.handlers.desugar.LambdaDesugarer +import java.lang.invoke.LambdaMetafactory +import org.junit.jupiter.api.Test +import org.objectweb.asm.Handle +import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type + +@Suppress( + "Convert2MethodRef", + "ObviousNullCheck", + "rawtypes", + "RedundantCast", + "ResultOfMethodCallIgnored", + "SwitchStatementWithTooFewBranches", + "unchecked", + "UnnecessarySemicolon" +) +class LambdaDesugarTest : AbstractDesugarTest() { + override val desugarer = LambdaDesugarer + + companion object { + internal val metafactory = Handle( + Opcodes.H_INVOKESTATIC, + "java/lang/invoke/LambdaMetafactory", + "metafactory", + "(Ljava/lang/invoke/MethodHandles\$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", + false + ) + private val altMetafactory = Handle( + Opcodes.H_INVOKESTATIC, + "java/lang/invoke/LambdaMetafactory", + "altMetafactory", + "(Ljava/lang/invoke/MethodHandles\$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;", + false + ) + } + + @Test + fun testSimpleMethodReference() { + doIndyTest( + """ + class Test { + void test() { + Runnable r = Test::method; + } + + static void method() { + } + } + """.trimIndent(), + """ + class Test { + private static Runnable synthetic1() { + } + + void test() { + Runnable r = Test.synthetic1(); + } + + static void method() { + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "()Ljava/lang/Runnable;", + metafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "method", + "()V", + false + ), + Type.getMethodType("()V") + ) + ) + } + + @Test + fun testBoundMethodReference() { + doIndyTest( + """ + class Test { + void test() { + Runnable r = this::method; + } + + void method() { + } + } + """.trimIndent(), + """ + import java.util.Objects; + + class Test { + private static Runnable synthetic1(Test param1) { + } + + void test() { + Runnable r = Test.synthetic1(Objects.requireNonNull(this)); + } + + void method() { + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "(LTest;)Ljava/lang/Runnable;", + metafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKEVIRTUAL, + "Test", + "method", + "()V", + false + ), + Type.getMethodType("()V"), + ) + ) + } + + @Test + fun testConstructorMethodReference() { + doIndyTest( + """ + import java.util.function.Supplier; + + class Test { + void test() { + Supplier s = Test::new; + } + } + """.trimIndent(), + """ + import java.util.function.Supplier; + + class Test { + private static Supplier synthetic1() { + } + + void test() { + Supplier s = Test.synthetic1(); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "get", + "()Ljava/util/function/Supplier;", + metafactory, + Type.getMethodType("()Ljava/lang/Object;"), + Handle( + Opcodes.H_NEWINVOKESPECIAL, + "Test", + "", + "()V", + false + ), + Type.getMethodType("()LTest;"), + ) + ) + } + + @Test + fun testSimpleLambda() { + doIndyTest( + """ + class Test { + void test() { + Runnable r = () -> { + }; + } + } + """.trimIndent(), + """ + class Test { + private static void lambda${'$'}test$0() { + } + + private static Runnable synthetic1() { + } + + void test() { + Runnable r = Test.synthetic1(); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "()Ljava/lang/Runnable;", + metafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$0", + "()V", + false + ), + Type.getMethodType("()V"), + ) + ) + } + + @Test + fun testNonStaticLambda() { + doIndyTest( + """ + class Test { + void test() { + Runnable r = () -> method(); + } + + void method() { + } + } + """.trimIndent(), + """ + class Test { + private static Runnable synthetic1(Test param1) { + } + + void test() { + Runnable r = Test.synthetic1(this); + } + + void method() { + } + + private void lambda${'$'}test$0() { + method(); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "(LTest;)Ljava/lang/Runnable;", + metafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKEVIRTUAL, + "Test", + "lambda\$test$0", + "()V", + false + ), + Type.getMethodType("()V"), + ) + ) + } + + @Test + fun testLambdaInConstructor() { + doTest( + """ + class Test { + Test() { + Runnable r = () -> { + }; + } + } + """.trimIndent(), + """ + class Test { + Test() { + Runnable r = Test.synthetic1(); + } + + private static void lambda${'$'}new$0() { + } + + private static Runnable synthetic1() { + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaInStaticInitializer() { + doTest( + """ + class Test { + static { + Runnable r = () -> { + }; + } + } + """.trimIndent(), + """ + class Test { + static { + Runnable r = Test.synthetic1(); + } + + private static void lambda${'$'}static$0() { + } + + private static Runnable synthetic1() { + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaWithCapture() { + doIndyTest( + """ + class Test { + void test(String capture) { + Runnable r = () -> System.out.println(capture); + } + } + """.trimIndent(), + """ + class Test { + private static void lambda${'$'}test$0(String capture) { + System.out.println(capture); + } + + private static Runnable synthetic1(String param1) { + } + + void test(String capture) { + Runnable r = Test.synthetic1(capture); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "(Ljava/lang/String;)Ljava/lang/Runnable;", + metafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$0", + "(Ljava/lang/String;)V", + false + ), + Type.getMethodType("()V"), + ) + ) + } + + @Test + fun testNonStaticLambdaWithCapture() { + doIndyTest( + """ + class Test { + void test(String capture) { + Runnable r = () -> method(capture); + } + + void method(String s) { + } + } + """.trimIndent(), + """ + class Test { + private static Runnable synthetic1(Test param1, String param2) { + } + + void test(String capture) { + Runnable r = Test.synthetic1(this, capture); + } + + void method(String s) { + } + + private void lambda${'$'}test$0(String capture) { + method(capture); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "(LTest;Ljava/lang/String;)Ljava/lang/Runnable;", + metafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKEVIRTUAL, + "Test", + "lambda\$test$0", + "(Ljava/lang/String;)V", + false + ), + Type.getMethodType("()V"), + ) + ) + } + + @Test + fun testLambdaMarkerInterface() { + doIndyTest( + """ + class Test { + void test() { + var r = (Runnable & Marker) () -> { + }; + } + + interface Marker { + } + } + """.trimIndent(), + """ + class Test { + private static void lambda${'$'}test$0() { + } + + private static Runnable synthetic1() { + } + + void test() { + var r = (Runnable & Marker) Test.synthetic1(); + } + + interface Marker { + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "()Ljava/lang/Runnable;", + altMetafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$0", + "()V", + false + ), + Type.getMethodType("()V"), + LambdaMetafactory.FLAG_MARKERS or LambdaMetafactory.FLAG_BRIDGES, + 1, + Type.getType("LTest\$Marker;"), + 0, + ) + ) + } + + @Test + fun testSerializableLambda() { + doIndyTest( + """ + import java.io.Serializable; + + class Test { + void test() { + var r = (Runnable & Serializable) () -> { + }; + } + } + """.trimIndent(), + """ + import java.io.Serializable; + import java.lang.invoke.SerializedLambda; + + class Test { + private static void lambda${'$'}test$2feeadd5$1() { + } + + private static Runnable synthetic1() { + } + + private static Object ${'$'}deserializeLambda$(SerializedLambda serializedLambda) { + switch (serializedLambda.getImplMethodName()) { + case "lambda${'$'}test$2feeadd5$1": + if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("java/lang/Runnable") && serializedLambda.getFunctionalInterfaceMethodName().equals("run") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()V") && serializedLambda.getImplClass().equals("Test") && serializedLambda.getImplMethodSignature().equals("()V")) + synthetic1(); + break; + } + throw new IllegalArgumentException("Invalid lambda deserialization"); + } + + void test() { + var r = (Runnable & Serializable) Test.synthetic1(); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "()Ljava/lang/Runnable;", + altMetafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$2feeadd5$1", + "()V", + false + ), + Type.getMethodType("()V"), + LambdaMetafactory.FLAG_SERIALIZABLE or LambdaMetafactory.FLAG_BRIDGES, + 0, + ) + ) + } + + @Test + fun testMultipleSerializableMethodReferences() { + doIndyTest( + """ + import java.io.Serializable; + + class Test { + I test1() { + return (I & Serializable) Test::method; + } + + J test2() { + return (J & Serializable) Test::method; + } + + private static void method() { + } + + @FunctionalInterface + interface I { + void foo(); + } + + @FunctionalInterface + interface J { + void foo(); + } + } + """.trimIndent(), + """ + import java.io.Serializable; + import java.lang.invoke.SerializedLambda; + + class Test { + private static J synthetic1() { + } + + private static I synthetic2() { + } + + private static Object ${'$'}deserializeLambda$(SerializedLambda serializedLambda) { + switch (serializedLambda.getImplMethodName()) { + case "method": + if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("Test${'$'}I") && serializedLambda.getFunctionalInterfaceMethodName().equals("foo") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()V") && serializedLambda.getImplClass().equals("Test") && serializedLambda.getImplMethodSignature().equals("()V")) + synthetic2(); + if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("Test${'$'}J") && serializedLambda.getFunctionalInterfaceMethodName().equals("foo") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()V") && serializedLambda.getImplClass().equals("Test") && serializedLambda.getImplMethodSignature().equals("()V")) + synthetic1(); + break; + } + throw new IllegalArgumentException("Invalid lambda deserialization"); + } + + I test1() { + return (I & Serializable) Test.synthetic2(); + } + + J test2() { + return (J & Serializable) Test.synthetic1(); + } + + private static void method() { + } + + @FunctionalInterface + interface I { + void foo(); + } + + @FunctionalInterface + interface J { + void foo(); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "foo", + "()LTest\$I;", + altMetafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "method", + "()V", + false + ), + Type.getMethodType("()V"), + LambdaMetafactory.FLAG_SERIALIZABLE or LambdaMetafactory.FLAG_BRIDGES, + 0, + ) + ) + } + + @Test + fun testSerializableLambdaWithCapture() { + doIndyTest( + """ + import java.io.Serializable; + + class Test { + void test(String capture1, int capture2, Object capture3) { + var r = (Runnable & Serializable) () -> System.out.println(capture1 + capture2 + capture3); + } + } + """.trimIndent(), + """ + import java.io.Serializable; + import java.lang.invoke.SerializedLambda; + + class Test { + private static void lambda${'$'}test$96740b79$1(String capture1, int capture2, Object capture3) { + System.out.println(capture1 + capture2 + capture3); + } + + private static Runnable synthetic1(String param1, int param2, Object param3) { + } + + private static Object ${'$'}deserializeLambda$(SerializedLambda serializedLambda) { + switch (serializedLambda.getImplMethodName()) { + case "lambda${'$'}test$96740b79$1": + if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("java/lang/Runnable") && serializedLambda.getFunctionalInterfaceMethodName().equals("run") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("()V") && serializedLambda.getImplClass().equals("Test") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/String;ILjava/lang/Object;)V")) + synthetic1((String) serializedLambda.getCapturedArg(0), (Integer) serializedLambda.getCapturedArg(1), serializedLambda.getCapturedArg(2)); + break; + } + throw new IllegalArgumentException("Invalid lambda deserialization"); + } + + void test(String capture1, int capture2, Object capture3) { + var r = (Runnable & Serializable) Test.synthetic1(capture1, capture2, capture3); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "run", + "(Ljava/lang/String;ILjava/lang/Object;)Ljava/lang/Runnable;", + altMetafactory, + Type.getMethodType("()V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$96740b79$1", + "(Ljava/lang/String;ILjava/lang/Object;)V", + false + ), + Type.getMethodType("()V"), + LambdaMetafactory.FLAG_SERIALIZABLE or LambdaMetafactory.FLAG_BRIDGES, + 0, + ) + ) + } + + @Test + fun testLambdaBridgeMethods() { + doIndyTest( + """ + class Test { + void test() { + K k = () -> ""; + } + + @FunctionalInterface + interface I { + String foo(); + } + + @FunctionalInterface + interface J { + Object foo(); + } + + interface K extends I, J { + } + } + """.trimIndent(), + """ + class Test { + private static String lambda${'$'}test$0() { + return ""; + } + + private static K synthetic1() { + } + + void test() { + K k = Test.synthetic1(); + } + + @FunctionalInterface + interface I { + String foo(); + } + + @FunctionalInterface + interface J { + Object foo(); + } + + interface K extends I, J { + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "foo", + "()LTest\$K;", + altMetafactory, + Type.getMethodType("()Ljava/lang/String;"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$0", + "()Ljava/lang/String;", + false + ), + Type.getMethodType("()Ljava/lang/String;"), + LambdaMetafactory.FLAG_BRIDGES, + 1, + Type.getMethodType("()Ljava/lang/Object;"), + ) + ) + } + + @Test + fun testLambdaGenericMethod() { + doIndyTest( + """ + import java.util.function.Consumer; + + class Test { + void test(T captured1, U captured2) { + Consumer c = u -> System.out.println("" + captured1 + captured2); + } + } + """.trimIndent(), + """ + import java.util.function.Consumer; + + class Test { + private static void lambda${'$'}test$0(T captured1, U captured2, U u) { + System.out.println("" + captured1 + captured2); + } + + private static Consumer synthetic1(Object param1, Object param2) { + } + + void test(T captured1, U captured2) { + Consumer c = Test.synthetic1(captured1, captured2); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "accept", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/util/function/Consumer;", + metafactory, + Type.getMethodType("(Ljava/lang/Object;)V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$0", + "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)V", + false + ), + Type.getMethodType("(Ljava/lang/Object;)V"), + ) + ) + } + + @Test + fun testLambdaIntersectionParameter() { + doTest( + """ + import java.io.Serializable; + + class Test { + void test() { + I i = t -> { + }; + } + + @FunctionalInterface + interface I { + void accept(T t); + } + } + """.trimIndent(), + """ + import java.io.Serializable; + + class Test { + private static void lambda${'$'}test$0(T t) { + } + + private static I synthetic1() { + } + + void test() { + I i = Test.synthetic1(); + } + + @FunctionalInterface + interface I { + void accept(T t); + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaWildcard() { + doTest( + """ + import java.util.function.Supplier; + + class Test { + void test() { + method(() -> ""); + } + + void method(Supplier s) { + } + } + """.trimIndent(), + """ + import java.util.function.Supplier; + + class Test { + private static Object lambda${'$'}test$0() { + return ""; + } + + private static Supplier synthetic1() { + } + + void test() { + method(Test.synthetic1()); + } + + void method(Supplier s) { + } + } + """.trimIndent() + ) + } +} Index: src/test/kotlin/platform/mixin/desugar/MethodReferenceAndLambdaDesugarTest.kt =================================================================== --- src/test/kotlin/platform/mixin/desugar/MethodReferenceAndLambdaDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/test/kotlin/platform/mixin/desugar/MethodReferenceAndLambdaDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,133 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.platform.mixin.desugar + +import com.demonwav.mcdev.platform.mixin.handlers.desugar.DesugarUtil +import com.demonwav.mcdev.platform.mixin.handlers.desugar.LambdaDesugarer +import com.demonwav.mcdev.platform.mixin.handlers.desugar.MethodReferenceToLambdaDesugarer +import org.junit.jupiter.api.Test +import org.objectweb.asm.Handle +import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type + +@Suppress("rawtypes") +class MethodReferenceAndLambdaDesugarTest : AbstractDesugarMultiTest() { + override val desugarers = listOf( + MethodReferenceToLambdaDesugarer, + LambdaDesugarer, + ) + + @Test + fun testQualifiedMethodReference() { + doIndyTest( + """ + import java.util.function.Consumer; + + class Test { + static void test() { + Consumer c = new Test()::method; + } + + private void method(String... args) { + } + } + """.trimIndent(), + """ + import java.util.function.Consumer; + + class Test { + static void test() { + Consumer c = Test.synthetic1(new Test()); + } + + private static void lambda${'$'}test$0(Test self, String args) { + self.method(args); + } + + private static Consumer synthetic1(Test param1) { + } + + private void method(String... args) { + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "accept", + "(LTest;)Ljava/util/function/Consumer;", + LambdaDesugarTest.metafactory, + Type.getMethodType("(Ljava/lang/Object;)V"), + Handle( + Opcodes.H_INVOKESTATIC, + "Test", + "lambda\$test$0", + "(LTest;Ljava/lang/String;)V", + false + ), + Type.getMethodType("(Ljava/lang/String;)V"), + ) + ) + } + + @Test + fun testSuperMethodReference() { + doIndyTest( + """ + import java.util.function.Supplier; + + class Test { + void test() { + Supplier s = super::clone; + } + } + """.trimIndent(), + """ + import java.util.function.Supplier; + + class Test { + private static Supplier synthetic1(Test param1) { + } + + void test() { + Supplier s = Test.synthetic1(this); + } + + private Object lambda${'$'}test$0() { + return super.clone(); + } + } + """.trimIndent(), + DesugarUtil.IndyData( + "get", + "(LTest;)Ljava/util/function/Supplier;", + LambdaDesugarTest.metafactory, + Type.getMethodType("()Ljava/lang/Object;"), + Handle( + Opcodes.H_INVOKEVIRTUAL, + "Test", + "lambda\$test$0", + "()Ljava/lang/Object;", + false + ), + Type.getMethodType("()Ljava/lang/Object;"), + ) + ) + } +} Index: src/test/kotlin/platform/mixin/desugar/MethodReferenceToLambdaDesugarTest.kt =================================================================== --- src/test/kotlin/platform/mixin/desugar/MethodReferenceToLambdaDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) +++ src/test/kotlin/platform/mixin/desugar/MethodReferenceToLambdaDesugarTest.kt (revision 341f481a921b0826535194f4dacf698de07b53ae) @@ -0,0 +1,356 @@ +/* + * Minecraft Development for IntelliJ + * + * https://mcdev.io/ + * + * Copyright (C) 2025 minecraft-dev + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation, version 3.0 only. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + */ + +package com.demonwav.mcdev.platform.mixin.desugar + +import com.demonwav.mcdev.platform.mixin.handlers.desugar.MethodReferenceToLambdaDesugarer +import org.junit.jupiter.api.Test +import org.objectweb.asm.Opcodes + +@Suppress("Convert2MethodRef") +class MethodReferenceToLambdaDesugarTest : AbstractDesugarTest() { + override val desugarer = MethodReferenceToLambdaDesugarer + + @Test + fun testNoLambda() { + doTestNoChange( + """ + class Test { + Runnable r = Test::run; + + private static void run() { + } + } + """.trimIndent() + ) + } + + @Test + fun testNoLambdaThis() { + doTestNoChange( + """ + class Test { + Runnable r = this::run; + + private void run() { + } + } + """.trimIndent() + ) + } + + @Test + fun testNoLambdaConstructor() { + doTestNoChange( + """ + import java.util.function.Supplier; + + class Test { + Supplier s = Test::new; + } + """.trimIndent() + ) + } + + @Test + fun testLambdaArray() { + doTest( + """ + import java.util.function.UnaryOperator; + + class Test { + UnaryOperator o = Object[]::clone; + } + """.trimIndent(), + """ + import java.util.function.UnaryOperator; + + class Test { + UnaryOperator o = objects -> objects.clone(); + } + """.trimIndent() + ) + } + + @Test + fun testLambdaSuper() { + doTest( + """ + import java.util.function.Supplier; + + class Test { + Supplier s = super::clone; + } + """.trimIndent(), + """ + import java.util.function.Supplier; + + class Test { + Supplier s = () -> super.clone(); + } + """.trimIndent() + ) + } + + @Test + fun testNoLambdaStaticInner() { + doTestNoChange( + """ + import java.util.function.Supplier; + + class Test { + Supplier s = Inner::new; + + static class Inner { + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaNonStaticInner() { + doTest( + """ + import java.util.function.Supplier; + + class Test { + Supplier s = Inner::new; + + class Inner { + } + } + """.trimIndent(), + """ + import java.util.function.Supplier; + + class Test { + Supplier s = () -> new Inner(); + + class Inner { + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaVarArgs() { + doTest( + """ + import java.util.function.BiFunction; + + class Test { + BiFunction f = Test::func; + BiFunction f2 = Test::func2; + + private static String func(String... args) { + return args[0]; + } + + private static String func2(String arg1, String... args) { + return arg1; + } + } + """.trimIndent(), + """ + import java.util.function.BiFunction; + + class Test { + BiFunction f = (args, args2) -> func(args, args2); + BiFunction f2 = (arg1, args) -> func2(arg1, args); + + private static String func(String... args) { + return args[0]; + } + + private static String func2(String arg1, String... args) { + return arg1; + } + } + """.trimIndent() + ) + } + + @Test + fun testInstanceLambdaVarArgs() { + doTest( + """ + import java.util.function.BiFunction; + + class Test { + BiFunction f = Test::func; + + private String func(String... args) { + return args[0]; + } + } + """.trimIndent(), + """ + import java.util.function.BiFunction; + + class Test { + BiFunction f = (test, args) -> test.func(args); + + private String func(String... args) { + return args[0]; + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaNoVarArgs() { + doTestNoChange( + """ + import java.util.function.Function; + + class Test { + Function f = Test::func; + + private static String func(String... args) { + return args[0]; + } + } + """.trimIndent() + ) + } + + @Test + fun testInstanceLambdaNoVarArgs() { + doTestNoChange( + """ + import java.util.function.BiFunction; + + class Test { + BiFunction f = Test::func; + + private String func(String... args) { + return args[0]; + } + } + """.trimIndent() + ) + } + + @Test + fun testNoLambdaPrivateInner() { + doTestNoChange( + """ + class Test { + Runnable r = Inner::method; + + static class Inner { + private static void method() { + } + } + } + """.trimIndent() + ) + } + + @Test + fun testLambdaPrivateInnerJava8() { + doTest( + """ + class Test { + Runnable r = Inner::method; + + static class Inner { + private static void method() { + } + } + } + """.trimIndent(), + """ + class Test { + Runnable r = () -> Inner.method(); + + static class Inner { + private static void method() { + } + } + } + """.trimIndent(), + classVersion = Opcodes.V1_8 + ) + } + + @Test + fun testLambdaProtected() { + doTest( + """ + import java.io.ObjectOutputStream; + import java.util.function.Consumer; + + class Test extends ObjectOutputStream { + Consumer c = this::writeObjectOverride; + } + """.trimIndent(), + """ + import java.io.ObjectOutputStream; + import java.util.function.Consumer; + + class Test extends ObjectOutputStream { + Consumer c = obj -> writeObjectOverride(obj); + } + """.trimIndent() + ) + } + + @Test + fun testLambdaIntersectionType() { + doTest( + """ + import java.io.Serializable; + + class Test { + F f = Test::method; + + private static void method(T t) { + } + + @FunctionalInterface + interface F { + void accept(T t); + } + } + """.trimIndent(), + """ + import java.io.Serializable; + + class Test { + F f = t -> { + method(t); + }; + + private static void method(T t) { + } + + @FunctionalInterface + interface F { + void accept(T t); + } + } + """.trimIndent() + ) + } +}