[FIX] InjectClassAcceptor fixes (by radioegor146)

This commit is contained in:
Gravit 2020-01-30 17:43:07 +07:00
parent 83a3c963d3
commit d813e714da
No known key found for this signature in database
GPG key ID: 061981E1E85D3216
3 changed files with 204 additions and 166 deletions

View file

@ -71,7 +71,7 @@ task cleanjar(type: Jar, dependsOn: jar) {
dependencies {
pack project(':LauncherAPI')
bundle 'org.ow2.asm:asm-commons:7.1'
bundle 'org.ow2.asm:asm-commons:7.3.1'
bundle 'mysql:mysql-connector-java:8.0.16'
bundle 'org.postgresql:postgresql:42.2.6'
bundle 'org.jline:jline:3.13.1'

View file

@ -1,26 +1,13 @@
package pro.gravit.launchserver.asm;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.objectweb.asm.tree.VarInsnNode;
import org.objectweb.asm.tree.*;
import pro.gravit.launchserver.binary.BuildContext;
import pro.gravit.launchserver.binary.tasks.MainBuildTask;
@ -33,167 +20,197 @@ public class InjectClassAcceptor implements MainBuildTask.ASMTransformer {
public InjectClassAcceptor(Map<String, Object> values) {
this.values = values;
}
private static final List<Class<?>> zPrimitivesList = Arrays.asList(java.lang.Boolean.class, java.lang.Character.class,
private static final List<Class<?>> primitiveClasses = Arrays.asList(java.lang.Boolean.class, java.lang.Character.class,
java.lang.Byte.class, java.lang.Short.class, java.lang.Integer.class, java.lang.Long.class,
java.lang.Float.class, java.lang.Double.class, java.lang.String.class);
private static final String INJ_DESC = Type.getDescriptor(LauncherInject.class);
private static final String INJ_C_DESC = Type.getDescriptor(LauncherInjectionConstructor.class);
private static final List<String> cPrimitivesList = Arrays.asList("I", "V", "Z", "B", "C", "S", "D", "F", "J", Type.getDescriptor(String.class));
private static void visit(ClassNode cn, Map<String, Object> object) {
MethodNode clinit = cn.methods.stream().filter(methodNode ->
"<clinit>".equals(methodNode.name)).findFirst().orElseGet(() -> {
MethodNode ret = new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC, "<clinit>", "()V", null, null);
ret.instructions.add(new InsnNode(Opcodes.RETURN));
cn.methods.add(ret);
return ret;
});
List<MethodNode> constructors = cn.methods.stream().filter(e -> "<init>".equals(e.name)).collect(Collectors.toList());
MethodNode init = constructors.stream().filter(e -> e != null && e.invisibleAnnotations != null && e.invisibleAnnotations.stream().anyMatch(f -> INJ_C_DESC.equals(f.desc))).findFirst()
.orElseGet(() -> constructors.stream().filter(e -> e.desc.equals("()V")).findFirst().orElse(null));
cn.fields.stream().filter(e -> e.invisibleAnnotations != null)
.filter(e -> !e.invisibleAnnotations.isEmpty() && e.invisibleAnnotations.stream().anyMatch(f -> f.desc.equals(INJ_DESC))).forEach(e -> {
// Notice that fields that will be used with this algo should not have default
// value by = ...;
AnnotationNode n = e.invisibleAnnotations.stream().filter(f -> INJ_DESC.equals(f.desc)).findFirst()
.get();
e.invisibleAnnotations.removeIf(f -> INJ_DESC.equals(f.desc));
AtomicReference<String> valueName = new AtomicReference<>(null);
n.accept(new AnnotationVisitor(Opcodes.ASM7) {
@Override
public void visit(final String name, final Object value) {
if ("value".equals(name)) {
if (value.getClass() != String.class)
throw new IllegalArgumentException(
"Invalid Annotation with value class " + e.getClass().getName());
valueName.set(value.toString());
}
}
});
if (valueName.get() == null)
throw new IllegalArgumentException("Annotation should always contains 'value' key");
if (object.containsKey(valueName.get())) {
Object val = object.get(valueName.get());
if ((e.access & Opcodes.ACC_STATIC) != 0)
if (cPrimitivesList.contains(e.desc) && zPrimitivesList.contains(val.getClass()))
e.value = val;
else {
List<FieldInsnNode> nodes = Arrays.stream(clinit.instructions.toArray()).filter(p -> p instanceof FieldInsnNode && p.getOpcode() == Opcodes.PUTSTATIC).map(p -> (FieldInsnNode) p)
.filter(p -> p.owner.equals(cn.name) && p.name.equals(e.name) && p.desc.equals(e.desc)).collect(Collectors.toList());
InsnList injector = new InsnList();
pushInjector(injector, val, e);
if (nodes.isEmpty()) {
injector.insert(new InsnNode(Opcodes.ICONST_0));
injector.add(new FieldInsnNode(Opcodes.PUTSTATIC, cn.name, e.name, e.desc));
Arrays.stream(clinit.instructions.toArray()).filter(p -> p.getOpcode() == Opcodes.RETURN).forEach(p -> clinit.instructions.insertBefore(p, injector));
} else
for (FieldInsnNode node : nodes) clinit.instructions.insertBefore(node, injector);
}
else {
if (init == null) throw new IllegalArgumentException("Not found init in target: " + cn.name);
List<FieldInsnNode> nodes = Arrays.stream(init.instructions.toArray()).filter(p -> p instanceof FieldInsnNode && p.getOpcode() == Opcodes.PUTFIELD).map(p -> (FieldInsnNode) p)
.filter(p -> p.owner.equals(cn.name) && p.name.equals(e.name) && p.desc.equals(e.desc)).collect(Collectors.toList());
InsnList injector = new InsnList();
pushInjector(injector, val, e);
if (nodes.isEmpty()) {
injector.insert(new VarInsnNode(Opcodes.ALOAD, 0));
injector.insert(new InsnNode(Opcodes.ICONST_0));
injector.add(new FieldInsnNode(Opcodes.PUTSTATIC, cn.name, e.name, e.desc));
Arrays.stream(init.instructions.toArray()).filter(p -> p.getOpcode() == Opcodes.RETURN).forEach(p -> clinit.instructions.insertBefore(p, injector));
} else
for (FieldInsnNode node : nodes) init.instructions.insertBefore(node, injector);
}
}
private static final String INJECTED_FIELD_DESC = Type.getDescriptor(LauncherInject.class);
private static final String INJECTED_CONSTRUCTOR_DESC = Type.getDescriptor(LauncherInjectionConstructor.class);
private static final List<String> primitiveDescriptors = Arrays.asList(Type.INT_TYPE.getDescriptor(),
Type.VOID_TYPE.getDescriptor(), Type.BOOLEAN_TYPE.getDescriptor(), Type.BYTE_TYPE.getDescriptor(),
Type.CHAR_TYPE.getDescriptor(), Type.SHORT_TYPE.getDescriptor(), Type.DOUBLE_TYPE.getDescriptor(),
Type.FLOAT_TYPE.getDescriptor(), Type.LONG_TYPE.getDescriptor(), Type.getDescriptor(String.class));
private static void visit(ClassNode classNode, Map<String, Object> values) {
MethodNode clinitMethod = classNode.methods.stream().filter(methodNode -> "<clinit>".equals(methodNode.name))
.findFirst().orElseGet(() -> {
MethodNode newClinitMethod = new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC,
"<clinit>", "()V", null, null);
newClinitMethod.instructions.add(new InsnNode(Opcodes.RETURN));
classNode.methods.add(newClinitMethod);
return newClinitMethod;
});
List<MethodNode> constructors = classNode.methods.stream().filter(method -> "<init>".equals(method.name))
.collect(Collectors.toList());
MethodNode initMethod = constructors.stream().filter(method -> method.invisibleAnnotations != null
&& method.invisibleAnnotations.stream().anyMatch(annotation -> INJECTED_CONSTRUCTOR_DESC.equals(annotation.desc))).findFirst()
.orElseGet(() -> constructors.stream().filter(method -> method.desc.equals("()V")).findFirst().orElse(null));
classNode.fields.forEach(field -> {
// Notice that fields that will be used with this algo should not have default
// value by = ...;
AnnotationNode valueAnnotation = field.invisibleAnnotations.stream()
.filter(annotation -> INJECTED_FIELD_DESC.equals(annotation.desc)).findFirst()
.orElse(null);
if (valueAnnotation == null) {
return;
}
field.invisibleAnnotations.remove(valueAnnotation);
AtomicReference<String> valueName = new AtomicReference<String>(null);
valueAnnotation.accept(new AnnotationVisitor(Opcodes.ASM7) {
@Override
public void visit(final String name, final Object value) {
if ("value".equals(name)) {
if (value.getClass() != String.class)
throw new IllegalArgumentException(
String.format("Invalid annotation with value class %s", field.getClass().getName()));
valueName.set(value.toString());
}
}
});
if (valueName.get() == null) {
throw new IllegalArgumentException("Annotation should always contains 'value' key");
}
if (!values.containsKey(valueName.get())) {
return;
}
Object value = values.get(valueName.get());
if ((field.access & Opcodes.ACC_STATIC) != 0) {
if (primitiveDescriptors.contains(field.desc) && primitiveClasses.contains(value.getClass())) {
field.value = value;
return;
}
List<FieldInsnNode> putStaticNodes = Arrays.stream(clinitMethod.instructions.toArray())
.filter(node -> node instanceof FieldInsnNode && node.getOpcode() == Opcodes.PUTSTATIC).map(p -> (FieldInsnNode) p)
.filter(node -> node.owner.equals(classNode.name) && node.name.equals(field.name) && node.desc.equals(field.desc)).collect(Collectors.toList());
InsnList setter = serializeValue(value);
if (putStaticNodes.isEmpty()) {
setter.add(new FieldInsnNode(Opcodes.PUTSTATIC, classNode.name, field.name, field.desc));
Arrays.stream(clinitMethod.instructions.toArray()).filter(node -> node.getOpcode() == Opcodes.RETURN)
.forEach(node -> clinitMethod.instructions.insertBefore(node, setter));
} else {
setter.insert(new InsnNode(Type.getType(field.desc).getSize() == 1 ? Opcodes.POP : Opcodes.POP2));
for (FieldInsnNode fieldInsnNode : putStaticNodes) {
clinitMethod.instructions.insertBefore(fieldInsnNode, setter);
}
}
} else {
if (initMethod == null) {
throw new IllegalArgumentException(String.format("Not found init in target: %s", classNode.name));
}
List<FieldInsnNode> putFieldNodes = Arrays.stream(initMethod.instructions.toArray())
.filter(node -> node instanceof FieldInsnNode && node.getOpcode() == Opcodes.PUTFIELD).map(p -> (FieldInsnNode) p)
.filter(node -> node.owner.equals(classNode.name) && node.name.equals(field.name) && node.desc.equals(field.desc)).collect(Collectors.toList());
InsnList setter = serializeValue(value);
if (putFieldNodes.isEmpty()) {
setter.insert(new VarInsnNode(Opcodes.ALOAD, 0));
setter.add(new FieldInsnNode(Opcodes.PUTFIELD, classNode.name, field.name, field.desc));
Arrays.stream(initMethod.instructions.toArray())
.filter(node -> node.getOpcode() == Opcodes.RETURN)
.forEach(node -> initMethod.instructions.insertBefore(node, setter));
} else {
setter.insert(new InsnNode(Type.getType(field.desc).getSize() == 1 ? Opcodes.POP : Opcodes.POP2));
for (FieldInsnNode fieldInsnNode : putFieldNodes) {
initMethod.instructions.insertBefore(fieldInsnNode, setter);
}
}
}
});
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private static void pushInjector(InsnList injector, Object val, FieldNode e) {
injector.add(new InsnNode(Opcodes.POP));
if (e.desc.equals("Z")) {
if ((Boolean) val) injector.add(new InsnNode(Opcodes.ICONST_1));
else injector.add(new InsnNode(Opcodes.ICONST_0));
} else if (e.desc.equals("C")) {
injector.add(NodeUtils.push(((Number) val).intValue()));
injector.add(new InsnNode(Opcodes.I2C));
} else if (e.desc.equals("B")) {
injector.add(NodeUtils.push(((Number) val).intValue()));
injector.add(new InsnNode(Opcodes.I2B));
} else if (e.desc.equals("S")) {
injector.add(NodeUtils.push(((Number) val).intValue()));
injector.add(new InsnNode(Opcodes.I2S));
} else if (e.desc.equals("I")) {
injector.add(NodeUtils.push(((Number) val).intValue()));
} else if (e.desc.equals("[B")) {
serializebArr(injector, (byte[]) val);
} else if (e.desc.equals("Ljava/util/List;")) {
if (((List) val).isEmpty()) {
injector.add(new TypeInsnNode(Opcodes.NEW, "java/util/ArrayList"));
injector.add(new InsnNode(Opcodes.DUP));
injector.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/util/ArrayList", "<init>", "()V"));
} else {
Class<?> c = ((List) val).get(0).getClass();
if (c == byte[].class)
serializeListbArr(injector, (List<byte[]>) val);
else if (c == String.class)
serializeListString(injector, (List<String>) val);
else
throw new UnsupportedOperationException("Unsupported class" + c.getName());
private static Map<Class<?>, Serializer<?>> serializers = new HashMap<>();
static {
serializers.put(List.class, new ListSerializer());
serializers.put(Map.class, new MapSerializer());
serializers.put(byte[].class, new ByteArraySerializer());
}
private interface Serializer<T> {
InsnList serialize(T value);
}
@SuppressWarnings("unchecked")
private static InsnList serializeValue(Object value) {
if (value == null) {
InsnList insnList = new InsnList();
insnList.add(new InsnNode(Opcodes.ACONST_NULL));
return insnList;
}
if (primitiveClasses.contains(value.getClass())) {
InsnList insnList = new InsnList();
insnList.add(new LdcInsnNode(value));
return insnList;
}
for (Map.Entry<Class<?>, Serializer<?>> serializerEntry : serializers.entrySet()) {
if (serializerEntry.getKey().isInstance(value)) {
return ((Serializer) serializerEntry.getValue()).serialize(value);
}
} else if (e.desc.equals("Ljava/util/Map;")) {
serializeMap(injector, (Map<String, String>) val);
} else {
if (!cPrimitivesList.contains(e.desc) || !zPrimitivesList.contains(val.getClass()))
throw new UnsupportedOperationException("Unsupported class");
injector.add(new LdcInsnNode(val));
}
throw new UnsupportedOperationException(String.format("Serialization of type %s is not supported",
value.getClass()));
}
private static class ListSerializer implements Serializer<List> {
@Override
public InsnList serialize(List value) {
InsnList insnList = new InsnList();
insnList.add(new TypeInsnNode(Opcodes.NEW, Type.getInternalName(ArrayList.class)));
insnList.add(new InsnNode(Opcodes.DUP));
insnList.add(NodeUtils.push(value.size()));
insnList.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, Type.getInternalName(ArrayList.class), "<init>",
Type.getMethodDescriptor(Type.VOID_TYPE, Type.INT_TYPE), false));
for (Object object : value) {
insnList.add(new InsnNode(Opcodes.DUP));
insnList.add(serializeValue(object));
insnList.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, Type.getInternalName(List.class), "add",
Type.getMethodDescriptor(Type.BOOLEAN_TYPE, Type.getType(Object.class)), true));
insnList.add(new InsnNode(Opcodes.POP));
}
return insnList;
}
}
private static void serializeMap(InsnList inj, Map<String, String> map) {
inj.add(new TypeInsnNode(Opcodes.NEW, "java/util/HashMap"));
inj.add(new InsnNode(Opcodes.DUP)); // +1
inj.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/util/HashMap", "<init>", "()V"));
map.forEach((k, v) -> {
inj.add(new InsnNode(Opcodes.DUP)); // +1-1
inj.add(NodeUtils.getSafeStringInsnList(k));
inj.add(NodeUtils.getSafeStringInsnList(v));
inj.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, "java/util/Map", "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", true));
inj.add(new InsnNode(Opcodes.POP));
});
private static class MapSerializer implements Serializer<Map> {
@Override
public InsnList serialize(Map value) {
InsnList insnList = new InsnList();
insnList.add(new TypeInsnNode(Opcodes.NEW, Type.getInternalName(value.getClass())));
insnList.add(new InsnNode(Opcodes.DUP));
insnList.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, Type.getInternalName(value.getClass()), "<init>",
Type.getMethodDescriptor(Type.VOID_TYPE), false));
for (Object entryObject : value.entrySet()) {
Map.Entry entry = (Map.Entry) entryObject;
insnList.add(new InsnNode(Opcodes.DUP));
insnList.add(serializeValue(entry.getKey()));
insnList.add(serializeValue(entry.getValue()));
insnList.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, Type.getInternalName(Map.class), "put",
Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class), Type.getType(Object.class)),
true));
insnList.add(new InsnNode(Opcodes.POP));
}
return insnList;
}
}
private static void serializebArr(InsnList injector, byte[] val) {
injector.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Base64", "getDecoder", "()Ljava/util/Base64$Decoder;", false));
injector.add(NodeUtils.getSafeStringInsnList(Base64.getEncoder().encodeToString(val)));
injector.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, "java/util/Base64$Decoder", "decode", "(Ljava/lang/String;)[B", false));
}
private static class ByteArraySerializer implements Serializer<byte[]> {
private static void serializeListbArr(InsnList inj, List<byte[]> val) {
inj.add(new TypeInsnNode(Opcodes.NEW, "java/util/ArrayList"));
inj.add(new InsnNode(Opcodes.DUP)); // +1
inj.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/util/ArrayList", "<init>", "()V"));
for (byte[] value : val) {
inj.add(new InsnNode(Opcodes.DUP)); // +1-1
serializebArr(inj, value);
inj.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, "java/util/List", "add", "(Ljava/lang/Object;)Z", true));
inj.add(new InsnNode(Opcodes.POP));
}
}
private static void serializeListString(InsnList inj, List<String> val) {
inj.add(new TypeInsnNode(Opcodes.NEW, "java/util/ArrayList"));
inj.add(new InsnNode(Opcodes.DUP)); // +1
inj.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/util/ArrayList", "<init>", "()V"));
for (String value : val) {
inj.add(new InsnNode(Opcodes.DUP)); // +1-1
inj.add(NodeUtils.getSafeStringInsnList(value));
inj.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, "java/util/List", "add", "(Ljava/lang/Object;)Z", true));
inj.add(new InsnNode(Opcodes.POP));
}
@Override
public InsnList serialize(byte[] value) {
InsnList insnList = new InsnList();
insnList.add(new MethodInsnNode(Opcodes.INVOKESTATIC, Type.getInternalName(Base64.class),
"getDecoder", Type.getMethodDescriptor(Type.getType(Base64.Decoder.class)), false));
insnList.add(NodeUtils.getSafeStringInsnList(Base64.getEncoder().encodeToString(value)));
insnList.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Base64.Decoder.class),
"decode", Type.getMethodDescriptor(Type.getType(byte[].class), Type.getType(String.class)),
false));
return insnList;
}
}
@Override
public void transform(ClassNode cn, String classname, BuildContext context) {
visit(cn, values);
public void transform(ClassNode classNode, String className, BuildContext context) {
visit(classNode, values);
}
}
}

View file

@ -16,7 +16,9 @@
import java.io.InputStream;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ASMTransformersTest {
@ -34,7 +36,11 @@ public void rawDefineClass(String name, byte[] bytes, int offset, int length)
public static class TestClass
{
@LauncherInject(value = "testprop")
public static int test = 1;
public int test;
@LauncherInject(value = "testprop2")
public List<String> s;
@LauncherInject(value = "testprop3")
public Map<String, String> map;
}
@BeforeAll
public static void prepare() throws Exception {
@ -49,6 +55,14 @@ void testASM() throws Exception
node.name = "ASMTestClass";
Map<String, Object> map = new HashMap<>();
map.put("testprop", 1234);
List<String> strings = new ArrayList<>();
strings.add("a");
strings.add("b");
map.put("testprop2", strings);
Map<String, String> byteMap = new HashMap<>();
byteMap.put("a", "TEST A");
byteMap.put("b", "TEST B");
map.put("testprop3", byteMap);
InjectClassAcceptor injectClassAcceptor = new InjectClassAcceptor(map);
injectClassAcceptor.transform(node, "ASMTestClass", null);
ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
@ -56,8 +70,15 @@ void testASM() throws Exception
byte[] bytes = writer.toByteArray();
classLoader.rawDefineClass("ASMTestClass", bytes, 0, bytes.length);
Class<?> clazz = classLoader.loadClass("ASMTestClass");
Object instance = clazz.newInstance();
Field field = clazz.getField("test");
Object result = field.get(null);
Object result = field.get(instance);
Assertions.assertEquals(1234, (int) (Integer) result);
field = clazz.getField("s");
result = field.get(instance);
Assertions.assertEquals(strings, result);
field = clazz.getField("map");
result = field.get(instance);
Assertions.assertEquals(byteMap, result);
}
}