From 7d65047b126c623dd91ff742ecbbdbaa0031b72d Mon Sep 17 00:00:00 2001 From: Karlatemp Date: Sat, 9 Jul 2022 20:23:38 +0800 Subject: [PATCH] `Root.MethodHandleLookup` --- .../karlatemp/unsafeaccessor/MHLookup.java | 69 +++++++ .../github/karlatemp/unsafeaccessor/Root.java | 151 +++++++++++++- .../BinaryCompatibilityAnalysis.java | 7 +- .../src/main/java/runtest/RunTestUnit.java | 11 ++ .../java/runtest/TestMethodHandleResolve.java | 186 ++++++++++++++++++ 5 files changed, 420 insertions(+), 4 deletions(-) create mode 100644 api/src/main/java/io/github/karlatemp/unsafeaccessor/MHLookup.java create mode 100644 impl/testunit/src/main/java/runtest/TestMethodHandleResolve.java diff --git a/api/src/main/java/io/github/karlatemp/unsafeaccessor/MHLookup.java b/api/src/main/java/io/github/karlatemp/unsafeaccessor/MHLookup.java new file mode 100644 index 0000000..8d8e5b1 --- /dev/null +++ b/api/src/main/java/io/github/karlatemp/unsafeaccessor/MHLookup.java @@ -0,0 +1,69 @@ +package io.github.karlatemp.unsafeaccessor; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; + +class MHLookup { + static MethodHandle lookupDirect(String name, MethodType type, boolean bind) throws NoSuchMethodException { + Object anyx = Unsafe.getUnsafe0().getOriginalUnsafe(); + MethodHandles.Lookup lookup = Root.RootLookupHolder.trustedIn(anyx.getClass()); + MethodHandle rsp; + try { + rsp = lookup.findVirtual(anyx.getClass(), name, type); + } catch (IllegalAccessException e) { + throw new InternalError(e); + } + return bind ? rsp.bindTo(anyx) : rsp; + } + + private static boolean checkUsingObj() { + Unsafe usf = Unsafe.getUnsafe0(); + if (!usf.isJava9()) return true; + return usf.getClass().getName().endsWith("Obj"); + } + + private static final boolean isUsingObj = checkUsingObj(); + + static MethodHandle lookup(String name, MethodType type, Object[] bind) throws NoSuchMethodException { + Unsafe usf = Unsafe.getUnsafe0(); + Object anyx = usf.getOriginalUnsafe(); + MethodHandles.Lookup lookup = Root.RootLookupHolder.trustedIn(anyx.getClass()); + String directName; + if (isUsingObj) { + directName = name.replace("Reference", "Object"); + } else directName = name; + + Object bindx = null; + MethodHandle rsp = null; + try { + try { + rsp = lookup.findVirtual(anyx.getClass(), directName, type); + bindx = anyx; + } catch (NoSuchMethodException ignored) { + if (!usf.isJava9()) { + // Opaque acquire Release + try { + String nwm = directName + .replace("Opaque", "Volatile") + .replace("Release", "Volatile") + .replace("Acquire", "Volatile"); + rsp = lookup.findVirtual(anyx.getClass(), directName, type); + bindx = anyx; + } catch (NoSuchMethodException ignored2) { + } + } + if (rsp == null) { + rsp = lookup.findVirtual(Unsafe.class, name, type); + bindx = Unsafe.getUnsafe0(); + } + } + } catch (IllegalAccessException e) { + throw new InternalError(e); + } + + if (bind == null) return rsp.bindTo(bindx); + bind[0] = bindx; + return rsp; + } +} diff --git a/api/src/main/java/io/github/karlatemp/unsafeaccessor/Root.java b/api/src/main/java/io/github/karlatemp/unsafeaccessor/Root.java index cec3e06..8de43e4 100644 --- a/api/src/main/java/io/github/karlatemp/unsafeaccessor/Root.java +++ b/api/src/main/java/io/github/karlatemp/unsafeaccessor/Root.java @@ -2,8 +2,7 @@ import org.jetbrains.annotations.Contract; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; +import java.lang.invoke.*; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Field; import java.util.function.Consumer; @@ -205,4 +204,152 @@ public static void initializeObject(Object instance) { Unsafe.getUnsafe0().ensureClassInitialized(instance.getClass()); ObjectInitializer.initializer().accept(instance); } + + /** + * Lookup for method handles + * + * @since 1.7.0 + */ + @SuppressWarnings("UnusedReturnValue") + public static class MethodHandleLookup { + private static void checkAccess(UnsafeAccess access) { + if (access == null) { + getUnsafe(); + } else { + access.checkTrusted(); + } + } + + public static MethodHandle lookup( + UnsafeAccess access, + String methodName, + MethodType methodType + ) throws NoSuchMethodException { + return lookup(access, methodName, methodType, false, true); + } + + /** + * Search a method handle from unsafe instance + * + * @param access The unsafe access object. No perm checking when access provided + * @param methodName The name of target method. Eg: {@code "getInt"} + * @param methodType The method type of target method.
Eg: {@code MethodType.methodType(int.class, long.class)} + * @param lookupDirect Do the direct search. + * @param doBind Bind the unsafe object to method handle. Provide `false` to run {@link MethodHandles.Lookup#revealDirect(MethodHandle)} + */ + public static MethodHandle lookup( + UnsafeAccess access, + String methodName, + MethodType methodType, + boolean lookupDirect, + boolean doBind + ) throws NoSuchMethodException { + checkAccess(access); + if (lookupDirect) { + return MHLookup.lookupDirect(methodName, methodType, doBind); + } else { + return MHLookup.lookup(methodName, methodType, doBind ? null : new Object[1]); + } + } + + private static UnsafeAccess detectUnsafeAccessHold(MethodHandles.Lookup lookup, MethodHandle mh) { + if (mh != null) { + if (mh.type().parameterCount() != 0) { + throw new IllegalArgumentException("parameters not empty: " + mh); + } + if (mh.type().returnType() != UnsafeAccess.class) { + throw new IllegalArgumentException("Provided method handle is not a access check handle."); + } + try { + return (UnsafeAccess) mh.invokeExact(); + } catch (Error | RuntimeException e) { + throw e; + } catch (Throwable throwable) { + throw new InternalError(throwable); + } + } + if (lookup == null) return null; + if (lookup == MethodHandles.publicLookup()) return null; + try { + try { + return detectUnsafeAccessHold(null, lookup.findStaticGetter( + lookup.lookupClass(), "UA", UnsafeAccess.class + )); + } catch (NoSuchFieldException ignored) { + } + try { + return detectUnsafeAccessHold(null, lookup.findStaticGetter( + lookup.lookupClass(), "UNSAFE_ACCESS", UnsafeAccess.class + )); + } catch (NoSuchFieldException ignored) { + } + } catch (IllegalAccessException ignored) { + throw new IllegalArgumentException("Provided caller <" + lookup + "> have no full access for itself"); + } + return null; + } + + public static CallSite resolveHandle( + MethodHandles.Lookup caller, + String methodName, + MethodType methodType + ) throws NoSuchMethodException { + return resolve(caller, methodName, methodType, 0, null); + } + + public static CallSite resolveHandleDirect( + MethodHandles.Lookup caller, + String methodName, + MethodType methodType + ) throws NoSuchMethodException { + return resolve(caller, methodName, methodType, 1, null); + } + + public static CallSite resolve( + MethodHandles.Lookup caller, + String methodName, + MethodType methodType, + int direct, + MethodHandle unsafeAccess_static_getter + ) throws NoSuchMethodException { + return resolve(caller, methodName, methodType, direct, 0, unsafeAccess_static_getter); + } + + /** + * Resolve a method handle for `invokeDynamic` + *

+ * Example:

{@code
+         * mymethod.visitInvokeDynamicInsn("getInt", "(J)I", new Handle(
+         *         Opcodes.H_INVOKESTATIC,
+         *         "io/github/karlatemp/unsafeaccessor/Root$MethodHandleLookup",
+         *         "resolve", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;IILjava/lang/invoke/MethodHandle;)Ljava/lang/invoke/CallSite;", false
+         * ), 0, 0, new Handle(
+         *         Opcodes.H_GETSTATIC,
+         *         "org/example/generated/MyClass", "UNSAFE_ACCESS", "Lio/github/karlatemp/unsafeaccessor/UnsafeAccess;",
+         *         false
+         * ));
+         * }
+ * + * @param noCallerCheck if {@code true}, will skip caller permission detect + * @param unsafeAccess_static_getter The handle to get an unsafe-access instance.
+ */ + public static CallSite resolve( + MethodHandles.Lookup caller, + String methodName, + MethodType methodType, + int direct, + int noCallerCheck, + MethodHandle unsafeAccess_static_getter + ) throws NoSuchMethodException { + return new ConstantCallSite( + lookup(detectUnsafeAccessHold(icb(noCallerCheck) ? null : caller, unsafeAccess_static_getter), + methodName, methodType, icb(direct), true + ) + ); + } + + private static boolean icb(int v) { + return v != 0; + } + } } diff --git a/impl/testunit/src/main/java/io/github/karlatemp/unsafeaccessor/BinaryCompatibilityAnalysis.java b/impl/testunit/src/main/java/io/github/karlatemp/unsafeaccessor/BinaryCompatibilityAnalysis.java index becba12..fa58a00 100644 --- a/impl/testunit/src/main/java/io/github/karlatemp/unsafeaccessor/BinaryCompatibilityAnalysis.java +++ b/impl/testunit/src/main/java/io/github/karlatemp/unsafeaccessor/BinaryCompatibilityAnalysis.java @@ -14,7 +14,6 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; -import java.util.Collection; import java.util.List; @TestTask(name = "BinaryCompatibilityAnalysis") @@ -121,7 +120,11 @@ private static void analyze(Class targetClass) throws Throwable { lookup.findConstructor(ownerClass, methodType); break; } - throw new AssertionError("INVOKESPECIAL with-out "); + MethodHandle handle = lookup.findVirtual(ownerClass, methodName, methodType); + MethodHandleInfo handleInfo = lookup.revealDirect(handle); + Assertions.assertEquals(handleInfo.getDeclaringClass(), ownerClass); + Assertions.assertEquals(handleInfo.getReferenceKind(), MethodHandleInfo.REF_invokeVirtual); + break; } default: throw new AssertionError("Unknown opcode: " + methodInsnNode.getOpcode() + "(" + Integer.toHexString(methodInsnNode.getOpcode()) + ")"); diff --git a/impl/testunit/src/main/java/runtest/RunTestUnit.java b/impl/testunit/src/main/java/runtest/RunTestUnit.java index 844bd46..f8d5bae 100644 --- a/impl/testunit/src/main/java/runtest/RunTestUnit.java +++ b/impl/testunit/src/main/java/runtest/RunTestUnit.java @@ -1,5 +1,8 @@ package runtest; +import io.github.karlatemp.unsafeaccessor.Unsafe; +import org.objectweb.asm.ClassWriter; + import java.util.ArrayList; import java.util.List; @@ -11,4 +14,12 @@ public static void main(String[] args) throws Throwable { classes.sort(String::compareTo); TestTasks.runTests(classes); } + + public static Class define(ClassWriter cw) { + return define(cw.toByteArray()); + } + + public static Class define(byte[] code) { + return Unsafe.getUnsafe().defineClass(null, code, 0, code.length, ClassLoader.getSystemClassLoader(), null); + } } diff --git a/impl/testunit/src/main/java/runtest/TestMethodHandleResolve.java b/impl/testunit/src/main/java/runtest/TestMethodHandleResolve.java new file mode 100644 index 0000000..c556e8f --- /dev/null +++ b/impl/testunit/src/main/java/runtest/TestMethodHandleResolve.java @@ -0,0 +1,186 @@ +package runtest; + +import io.github.karlatemp.unsafeaccessor.Root; +import io.github.karlatemp.unsafeaccessor.Unsafe; +import io.github.karlatemp.unsafeaccessor.UnsafeAccess; +import org.junit.jupiter.api.Assertions; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.Handle; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; + +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +public class TestMethodHandleResolve { + MethodType mhType(boolean oarg, String actType, Class type) { + if (actType.equals("get")) { + if (oarg) { + return MethodType.methodType(type, Object.class, long.class); + } else { + return MethodType.methodType(type, long.class); + } + } else { + if (oarg) { + return MethodType.methodType(void.class, Object.class, long.class, type); + } else { + return MethodType.methodType(void.class, long.class, type); + } + } + } + + @TestTask + void run() throws Throwable { + String[] actTypes = {"get", "put"}; + Map> typeMap = new HashMap<>(); + { + Object[] pvMap = { + "Byte", /**/byte.class, + "Char", /**/char.class, + "Short", /**/short.class, + "Int", /**/int.class, + "Long", /**/long.class, + "Float", /**/float.class, + "Double", /**/double.class, + "Boolean", /**/boolean.class, + "Reference", /**/Object.class, + }; + for (int i = 0; i < pvMap.length; i += 2) { + typeMap.put(pvMap[i].toString(), (Class) pvMap[i + 1]); + } + } + Collection> noDirectAddrSet = Arrays.asList(boolean.class, Object.class); + String[] privTypes = typeMap.keySet().toArray(new String[0]); + String[] primTypes = { + "Byte", "Char", "Short", "Int", "Long", + "Float", "Double", + "Boolean", + }; + String[] unalignedTypes = { + "Char", "Short", "Int", "Long", + }; + UnsafeAccess ua = UnsafeAccess.getInstance(); + for (String actType : actTypes) { + for (String privType : privTypes) { + String methodName = actType + privType; + if (!noDirectAddrSet.contains(typeMap.get(privType))) { + Root.MethodHandleLookup.lookup(ua, methodName, mhType( + false, actType, typeMap.get(privType) + )); + } + Root.MethodHandleLookup.lookup(ua, methodName, mhType( + true, actType, typeMap.get(privType) + )); + } + for (String primType : primTypes) { + String methodName = actType + primType; + if (!noDirectAddrSet.contains(typeMap.get(primType))) { + Root.MethodHandleLookup.lookup(ua, methodName, mhType( + false, actType, typeMap.get(primType) + ), true, false); + } + Root.MethodHandleLookup.lookup(ua, methodName, mhType( + true, actType, typeMap.get(primType) + ), true, false); + } + } + + // all unsafe methods should be able to be resolve + for (Method m : Unsafe.class.getDeclaredMethods()) { + if (Modifier.isStatic(m.getModifiers())) continue; + if (!Modifier.isPublic(m.getModifiers())) continue; + Root.MethodHandleLookup.lookup(ua, + m.getName(), MethodType.methodType(m.getReturnType(), m.getParameterTypes()), + false, false + ); + } + } + + @TestTask + void testInvokeDynamic() throws Throwable { + ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES); + writer.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, "sawet/sda22te2/A23fx", null, "java/lang/Object", null); + + writer.visitField(Opcodes.ACC_STATIC, "UNSAFE_ACCESS", "Lio/github/karlatemp/unsafeaccessor/UnsafeAccess;", null, null); + MethodVisitor mymethod; + + mymethod = writer.visitMethod(Opcodes.ACC_STATIC, "test1", "(J)I", null, null); + mymethod.visitVarInsn(Opcodes.LLOAD, 0); + mymethod.visitInvokeDynamicInsn("getInt", "(J)I", + new Handle(Opcodes.H_INVOKESTATIC, + "io/github/karlatemp/unsafeaccessor/Root$MethodHandleLookup", + "resolve", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;IILjava/lang/invoke/MethodHandle;)Ljava/lang/invoke/CallSite;", + false + ), + false, false, + new Handle(Opcodes.H_GETSTATIC, + "sawet/sda22te2/A23fx", + "UNSAFE_ACCESS", "Lio/github/karlatemp/unsafeaccessor/UnsafeAccess;", + false + ) + ); + mymethod.visitInsn(Opcodes.IRETURN); + mymethod.visitMaxs(0, 0); + + + mymethod = writer.visitMethod(Opcodes.ACC_STATIC, "test2", "(J)I", null, null); + mymethod.visitVarInsn(Opcodes.LLOAD, 0); + mymethod.visitInvokeDynamicInsn("getInt", "(J)I", + new Handle(Opcodes.H_INVOKESTATIC, + "io/github/karlatemp/unsafeaccessor/Root$MethodHandleLookup", + "resolveHandle", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", + false + ) + ); + mymethod.visitInsn(Opcodes.IRETURN); + mymethod.visitMaxs(0, 0); + + + mymethod = writer.visitMethod(Opcodes.ACC_STATIC, "test3", "(J)I", null, null); + mymethod.visitVarInsn(Opcodes.LLOAD, 0); + mymethod.visitInvokeDynamicInsn("getInt", "(J)I", + new Handle(Opcodes.H_INVOKESTATIC, + "io/github/karlatemp/unsafeaccessor/Root$MethodHandleLookup", + "resolve", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;ILjava/lang/invoke/MethodHandle;)Ljava/lang/invoke/CallSite;", + false + ), + false, + new Handle(Opcodes.H_GETSTATIC, + "sawet/sda22te2/A23fx", + "UNSAFE_ACCESS", "Lio/github/karlatemp/unsafeaccessor/UnsafeAccess;", + false + ) + ); + mymethod.visitInsn(Opcodes.IRETURN); + mymethod.visitMaxs(0, 0); + + + + Class klass = RunTestUnit.define(writer); + Unsafe usf = Unsafe.getUnsafe(); + MethodHandles.Lookup lk = Root.getTrusted(klass); + long sizex = usf.allocateMemory(Integer.SIZE); + int rsp; + + usf.putInt(sizex, 114); + rsp = (int) lk.findStatic(klass, "test1", MethodType.methodType(int.class, long.class)).invoke(sizex); + Assertions.assertEquals(114, rsp); + + + usf.putInt(sizex, 514); + rsp = (int) lk.findStatic(klass, "test2", MethodType.methodType(int.class, long.class)).invoke(sizex); + Assertions.assertEquals(514, rsp); + + usf.putInt(sizex, 1919); + rsp = (int) lk.findStatic(klass, "test3", MethodType.methodType(int.class, long.class)).invoke(sizex); + Assertions.assertEquals(1919, rsp); + + + } +}