User: rednesto Date: 16 Jul 23 14:42 Revision: 152abf41a45014c351f14dae843050749be8df06 Summary: Recognize ModifyConstant injections without Constant supplied. In this case Mixin filters the constants based on the method's return type. TeamCity URL: https://ci.mcdev.io/viewModification.html?tab=vcsModificationFiles&modId=8635&personal=false Index: src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt (revision ea427f844a94df607d0f355fb932e40f98814406) +++ src/main/kotlin/platform/mixin/handlers/ModifyConstantHandler.kt (revision 152abf41a45014c351f14dae843050749be8df06) @@ -38,9 +38,11 @@ import com.intellij.psi.PsiElement import com.intellij.psi.PsiEnumConstant import com.intellij.psi.PsiManager +import com.intellij.psi.PsiMethod import com.intellij.psi.PsiReferenceExpression import com.intellij.psi.PsiType import com.intellij.psi.search.GlobalSearchScope +import com.intellij.psi.util.parentOfType import org.objectweb.asm.Opcodes import org.objectweb.asm.Type import org.objectweb.asm.tree.AbstractInsnNode @@ -141,7 +143,28 @@ targetClass: ClassNode, targetMethod: MethodNode, ): List? 
{ - val constantInfos = getConstantInfos(annotation) ?: return null + val constantInfos = getConstantInfos(annotation) + if (constantInfos == null) { + val method = annotation.parentOfType() + ?: return emptyList() + val returnType = method.returnType + ?: return emptyList() + val constantParamName = method.parameterList.getParameter(0)?.name ?: "constant" + return listOf( + MethodSignature( + listOf( + ParameterGroup(listOf(sanitizedParameter(returnType, constantParamName))), + ParameterGroup( + collectTargetMethodParameters(annotation.project, targetClass, targetMethod), + isVarargs = true, + required = ParameterGroup.RequiredLevel.OPTIONAL, + ), + ), + returnType, + ) + ) + } + val psiManager = PsiManager.getInstance(annotation.project) return constantInfos.asSequence().map { when (it.constantInfo.constant) { @@ -174,8 +197,6 @@ targetClass: ClassNode, targetMethod: MethodNode, ): List { - val constantInfos = getConstantInfos(annotation) ?: return emptyList() - val targetElement = targetMethod.findSourceElement( targetClass, annotation.project, @@ -183,6 +204,31 @@ canDecompile = true, ) ?: return emptyList() + val constantInfos = getConstantInfos(annotation) + if (constantInfos == null) { + val returnType = annotation.parentOfType()?.returnType + ?: return emptyList() + + val collectVisitor = ConstantInjectionPoint.MyCollectVisitor( + annotation.project, + CollectVisitor.Mode.MATCH_ALL, + null, + Type.getType(returnType.descriptor) + ) + collectVisitor.visit(targetMethod) + val bytecodeResults = collectVisitor.result + + val navigationVisitor = ConstantInjectionPoint.MyNavigationVisitor( + null, + Type.getType(returnType.descriptor) + ) + targetElement.accept(navigationVisitor) + + return bytecodeResults.asSequence().mapNotNull { bytecodeResult -> + navigationVisitor.result.getOrNull(bytecodeResult.index) + }.sortedBy { it.textOffset }.toList() + } + val constantInjectionPoint = InjectionPoint.byAtCode("CONSTANT") as? 
ConstantInjectionPoint ?: return emptyList() @@ -215,7 +261,21 @@ targetMethod: MethodNode, mode: CollectVisitor.Mode, ): List> { - val constantInfos = getConstantInfos(annotation) ?: return emptyList() + val constantInfos = getConstantInfos(annotation) + if (constantInfos == null) { + val returnType = annotation.parentOfType()?.returnType + ?: return emptyList() + + val collectVisitor = ConstantInjectionPoint.MyCollectVisitor( + annotation.project, + mode, + null, + Type.getType(returnType.descriptor) + ) + collectVisitor.visit(targetMethod) + return collectVisitor.result.sortedBy { targetMethod.instructions.indexOf(it.insn) } + } + val constantInjectionPoint = InjectionPoint.byAtCode("CONSTANT") as? ConstantInjectionPoint ?: return emptyList() return constantInfos.asSequence().flatMap { modifyConstantInfo -> @@ -239,10 +299,28 @@ targetClass: ClassNode, targetMethod: MethodNode, ): InsnResolutionInfo.Failure? { - val constantInfos = getConstantInfos(annotation) ?: return InsnResolutionInfo.Failure() + val constantInfos = getConstantInfos(annotation) + if (constantInfos == null) { + val returnType = annotation.parentOfType()?.returnType + ?: return InsnResolutionInfo.Failure() + + val collectVisitor = ConstantInjectionPoint.MyCollectVisitor( + annotation.project, + CollectVisitor.Mode.MATCH_FIRST, + null, + Type.getType(returnType.descriptor) + ) + collectVisitor.visit(targetMethod) + return if (collectVisitor.result.isEmpty()) { + InsnResolutionInfo.Failure(collectVisitor.filterToBlame) + } else { + null + } + } + val constantInjectionPoint = InjectionPoint.byAtCode("CONSTANT") as? 
ConstantInjectionPoint ?: return null - return constantInfos.asSequence().mapNotNull { modifyConstantInfo -> + return constantInfos.firstNotNullOfOrNull { modifyConstantInfo -> val collectVisitor = ConstantInjectionPoint.MyCollectVisitor( annotation.project, CollectVisitor.Mode.MATCH_FIRST, @@ -255,12 +333,12 @@ ) collectVisitor.visit(targetMethod) if (collectVisitor.result.isEmpty()) { - collectVisitor.filterToBlame + InsnResolutionInfo.Failure(collectVisitor.filterToBlame) } else { null } - }.firstOrNull()?.let(InsnResolutionInfo::Failure) - } + } + } override fun isInsnAllowed(insn: AbstractInsnNode): Boolean { return insn.opcode in allowedOpcodes Index: src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantInjectionPoint.kt =================================================================== --- src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantInjectionPoint.kt (revision ea427f844a94df607d0f355fb932e40f98814406) +++ src/main/kotlin/platform/mixin/handlers/injectionPoint/ConstantInjectionPoint.kt (revision 152abf41a45014c351f14dae843050749be8df06) @@ -27,6 +27,7 @@ import com.demonwav.mcdev.util.ifNotBlank import com.intellij.codeInsight.lookup.LookupElementBuilder import com.intellij.openapi.project.Project +import com.intellij.psi.CommonClassNames import com.intellij.psi.JavaPsiFacade import com.intellij.psi.JavaTokenType import com.intellij.psi.PsiAnnotation @@ -103,9 +104,8 @@ at: PsiAnnotation, target: MixinSelector?, targetClass: PsiClass, - ): NavigationVisitor? { - val constantInfo = getConstantInfo(at) ?: return null - return MyNavigationVisitor(constantInfo) + ): NavigationVisitor { + return MyNavigationVisitor(getConstantInfo(at)) } override fun doCreateCollectVisitor( @@ -113,9 +113,8 @@ target: MixinSelector?, targetClass: ClassNode, mode: CollectVisitor.Mode, - ): CollectVisitor? 
{ - val constantInfo = getConstantInfo(at) ?: return null - return MyCollectVisitor(at.project, mode, constantInfo) + ): CollectVisitor { + return MyCollectVisitor(at.project, mode, getConstantInfo(at)) } override fun createLookup( @@ -135,7 +134,8 @@ } class MyNavigationVisitor( - private val constantInfo: ConstantInfo, + private val constantInfo: ConstantInfo?, + private val expectedType: Type? = null, ) : NavigationVisitor() { override fun visitForeachStatement(statement: PsiForeachStatement) { if (statement.iteratedValue?.type is PsiArrayType) { @@ -169,10 +169,35 @@ } private fun visitConstant(element: PsiElement, value: Any?) { - if (value != constantInfo.constant) { + if (constantInfo != null && value != constantInfo.constant) { return } + if (expectedType != null && value != null) { + // First check if we expect any String literal + if (value is String && + (expectedType.sort != Type.OBJECT || expectedType.className != CommonClassNames.JAVA_LANG_STRING) + ) { + return + } + + // then check if we expect any class literal + if (value is Type && ( + expectedType.sort != Type.ARRAY && expectedType.sort != Type.OBJECT || + expectedType.className != CommonClassNames.JAVA_LANG_CLASS + ) + ) { + return + } + + // otherwise we expect a primitive literal + if (expectedType.sort in Type.BOOLEAN..Type.DOUBLE && + value::class.javaPrimitiveType?.let(Type::getType) != expectedType + ) { + return + } + } + val parent = PsiUtil.skipParenthesizedExprUp(element.parent) // check for expandZeroConditions @@ -189,7 +214,11 @@ JavaTokenType.GE -> Opcodes.IFGE else -> null } - if (opcode != null && !constantInfo.expandConditions.any { opcode in it.opcodes }) { + if (opcode != null && ( + constantInfo == null || + !constantInfo.expandConditions.any { opcode in it.opcodes } + ) + ) { return } } @@ -207,10 +236,10 @@ class MyCollectVisitor( private val project: Project, mode: Mode, - private val constantInfo: ConstantInfo, + private val constantInfo: ConstantInfo?, + private val 
expectedType: Type? = null, ) : CollectVisitor(mode) { override fun accept(methodNode: MethodNode) { - val elementFactory = JavaPsiFacade.getElementFactory(project) methodNode.instructions?.iterator()?.forEachRemaining { insn -> val constant = when (insn) { is InsnNode -> when (insn.opcode) { @@ -225,13 +254,15 @@ Opcodes.ACONST_NULL -> null else -> return@forEachRemaining } + is IntInsnNode -> when (insn.opcode) { Opcodes.BIPUSH, Opcodes.SIPUSH -> insn.operand else -> return@forEachRemaining } + is LdcInsnNode -> insn.cst is JumpInsnNode -> { - if (!constantInfo.expandConditions.any { insn.opcode in it.opcodes }) { + if (constantInfo == null || !constantInfo.expandConditions.any { insn.opcode in it.opcodes }) { return@forEachRemaining } var lastInsn = insn.previous @@ -251,17 +282,49 @@ } 0 } + else -> return@forEachRemaining } - if (constant == constantInfo.constant) { + + if (constantInfo != null && constant != constantInfo.constant) { + return@forEachRemaining + } + + if (expectedType != null && constant != null) { + // First check if we expect any String literal + if (constant is String && ( + expectedType.sort != Type.OBJECT || + expectedType.className != CommonClassNames.JAVA_LANG_STRING + ) + ) { + return@forEachRemaining + } + + // then check if we expect any class literal + if (constant is Type && ( + expectedType.sort != Type.ARRAY && expectedType.sort != Type.OBJECT || + expectedType.className != CommonClassNames.JAVA_LANG_CLASS + ) + ) { + return@forEachRemaining + } + + // otherwise we expect a primitive literal + if (expectedType.sort in Type.BOOLEAN..Type.DOUBLE && + constant::class.javaPrimitiveType?.let(Type::getType) != expectedType + ) { + return@forEachRemaining + } + } + + val elementFactory = JavaPsiFacade.getElementFactory(project) - val literal = if (constant is Type) { - elementFactory.createExpressionFromText("${constant.className}.class", null) - } else { - elementFactory.createLiteralExpression(constant) - } - addResult(insn, literal) - 
} - } - } - } + val literal = if (constant is Type) { + elementFactory.createExpressionFromText("${constant.className}.class", null) + } else { + elementFactory.createLiteralExpression(constant) + } + addResult(insn, literal) + } + } + } +} -}