diff options
Diffstat (limited to 'ARMeilleure/Signal')
| -rw-r--r-- | ARMeilleure/Signal/NativeSignalHandler.cs | 21 | ||||
| -rw-r--r-- | ARMeilleure/Signal/TestMethods.cs | 84 | ||||
| -rw-r--r-- | ARMeilleure/Signal/WindowsPartialUnmapHandler.cs | 186 | ||||
| -rw-r--r-- | ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs | 23 |
4 files changed, 311 insertions, 3 deletions
diff --git a/ARMeilleure/Signal/NativeSignalHandler.cs b/ARMeilleure/Signal/NativeSignalHandler.cs index cad0d420..0257f440 100644 --- a/ARMeilleure/Signal/NativeSignalHandler.cs +++ b/ARMeilleure/Signal/NativeSignalHandler.cs @@ -197,12 +197,29 @@ namespace ARMeilleure.Signal // Only call tracking if in range. context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold); - context.Copy(inRegionLocal, Const(1)); Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask)); // Call the tracking action, with the pointer's relative offset to the base address. Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20)); - context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0)); + + context.Copy(inRegionLocal, Const(0)); + + Operand skipActionLabel = Label(); + + // Tracking action should be non-null to call it, otherwise assume false return. + context.BranchIfFalse(skipActionLabel, trackingActionPtr); + Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0)); + context.Copy(inRegionLocal, result); + + context.MarkLabel(skipActionLabel); + + // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows. + if (OperatingSystem.IsWindows()) + { + context.BranchIfTrue(endLabel, inRegionLocal); + + context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context)); + } context.Branch(endLabel); diff --git a/ARMeilleure/Signal/TestMethods.cs b/ARMeilleure/Signal/TestMethods.cs new file mode 100644 index 00000000..2d7cef16 --- /dev/null +++ b/ARMeilleure/Signal/TestMethods.cs @@ -0,0 +1,84 @@ +using ARMeilleure.IntermediateRepresentation; +using ARMeilleure.Translation; +using System; + +using static ARMeilleure.IntermediateRepresentation.Operand.Factory; + +namespace ARMeilleure.Signal +{ + public struct NativeWriteLoopState + { + public int Running; + public int Error; + } + + public static class TestMethods + { + public delegate bool DebugPartialUnmap(); + public delegate int DebugThreadLocalMapGetOrReserve(int threadId, int initialState); + public delegate void DebugNativeWriteLoop(IntPtr nativeWriteLoopPtr, IntPtr writePtr); + + public static DebugPartialUnmap GenerateDebugPartialUnmap() + { + EmitterContext context = new EmitterContext(); + + var result = WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context); + + context.Return(result); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugPartialUnmap>(); + } + + public static DebugThreadLocalMapGetOrReserve GenerateDebugThreadLocalMapGetOrReserve(IntPtr structPtr) + { + EmitterContext context = new EmitterContext(); + + var result = WindowsPartialUnmapHandler.EmitThreadLocalMapIntGetOrReserve(context, structPtr, context.LoadArgument(OperandType.I32, 0), context.LoadArgument(OperandType.I32, 1)); + + context.Return(result); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugThreadLocalMapGetOrReserve>(); + } + + public static DebugNativeWriteLoop GenerateDebugNativeWriteLoop() + { + EmitterContext context = new EmitterContext(); + + // Loop a write to the target address until "running" is false. + + Operand structPtr = context.Copy(context.LoadArgument(OperandType.I64, 0)); + Operand writePtr = context.Copy(context.LoadArgument(OperandType.I64, 1)); + + Operand loopLabel = Label(); + context.MarkLabel(loopLabel); + + context.Store(writePtr, Const(12345)); + + Operand running = context.Load(OperandType.I32, structPtr); + + context.BranchIfTrue(loopLabel, running); + + context.Return(); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<DebugNativeWriteLoop>(); + } + } +} diff --git a/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs b/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs new file mode 100644 index 00000000..941e36e5 --- /dev/null +++ b/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs @@ -0,0 +1,186 @@ +using ARMeilleure.IntermediateRepresentation; +using ARMeilleure.Translation; +using Ryujinx.Common.Memory.PartialUnmaps; +using System; + +using static ARMeilleure.IntermediateRepresentation.Operand.Factory; + +namespace ARMeilleure.Signal +{ + /// <summary> + /// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods. + /// </summary> + internal static class WindowsPartialUnmapHandler + { + public static Operand EmitRetryFromAccessViolation(EmitterContext context) + { + IntPtr partialRemapStatePtr = PartialUnmapState.GlobalState; + IntPtr localCountsPtr = IntPtr.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset); + + // Get the lock first. + EmitNativeReaderLockAcquire(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset)); + + IntPtr getCurrentThreadId = WindowsSignalHandlerRegistration.GetCurrentThreadIdFunc(); + Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32); + Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0)); + + Operand endLabel = Label(); + Operand retry = context.AllocateLocal(OperandType.I32); + Operand threadIndexValidLabel = Label(); + + context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1))); + + context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated. + + context.Branch(endLabel); + + context.MarkLabel(threadIndexValidLabel); + + Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex); + Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr); + Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset))); + + context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount)); + + Operand noRetryLabel = Label(); + + context.BranchIfFalse(noRetryLabel, retry); + + // if (retry) { + + context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount); + + context.Branch(endLabel); + + context.MarkLabel(noRetryLabel); + + // } + + context.MarkLabel(endLabel); + + // Finally, release the lock and return the retry value. + EmitNativeReaderLockRelease(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset)); + + return retry; + } + + public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand initialState) + { + Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset)); + + Operand i = context.AllocateLocal(OperandType.I32); + + context.Copy(i, Const(0)); + + // (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate) + + Operand endLabel = Label(); + + Operand loopLabel = Label(); + context.MarkLabel(loopLabel); + + Operand offset = context.Multiply(i, Const(sizeof(int))); + Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset)); + + // Check that this slot has the thread ID. + Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId); + + // If it was already the thread ID, then we just need to return i. + context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId)); + + context.Copy(i, context.Add(i, Const(1))); + + context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize))); + + // (Loop 2) Try take a slot that is 0 with our Thread ID. + + context.Copy(i, Const(0)); // Reset i. + + Operand loop2Label = Label(); + context.MarkLabel(loop2Label); + + Operand offset2 = context.Multiply(i, Const(sizeof(int))); + Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2)); + + // Try and swap in the thread id on top of 0. + Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId); + + Operand idNot0Label = Label(); + + // If it was 0, then we need to initialize the struct entry and return i. + context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0))); + + Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset)); + Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2)); + context.Store(structPtr, initialState); + + context.Branch(endLabel); + + context.MarkLabel(idNot0Label); + + context.Copy(i, context.Add(i, Const(1))); + + context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize))); + + context.Copy(i, Const(-1)); // Could not place the thread in the list. + + context.MarkLabel(endLabel); + + return context.Copy(i); + } + + private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, IntPtr threadLocalMapPtr, Operand index) + { + Operand offset = context.Multiply(index, Const(sizeof(int))); + Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset)); + + return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset)); + } + + private static void EmitThreadLocalMapIntRelease(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand index) + { + Operand offset = context.Multiply(index, Const(sizeof(int))); + Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset)); + Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset)); + + context.CompareAndSwap(idPtr, threadId, Const(0)); + } + + private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive) + { + Operand loop = Label(); + context.MarkLabel(loop); + + Operand initial = context.Load(OperandType.I32, ptr); + Operand newValue = context.Add(initial, additive); + + Operand replaced = context.CompareAndSwap(ptr, initial, newValue); + + context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced)); + } + + private static void EmitNativeReaderLockAcquire(EmitterContext context, IntPtr nativeReaderLockPtr) + { + Operand writeLockPtr = Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset)); + + // Spin until we can acquire the write lock. + Operand spinLabel = Label(); + context.MarkLabel(spinLabel); + + // Old value must be 0 to continue (we gained the write lock) + context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1))); + + // Increment reader count. + EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1)); + + // Release write lock. + context.CompareAndSwap(writeLockPtr, Const(1), Const(0)); + } + + private static void EmitNativeReaderLockRelease(EmitterContext context, IntPtr nativeReaderLockPtr) + { + // Decrement reader count. + EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1)); + } + } +} diff --git a/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs b/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs index 959d1c47..513829a6 100644 --- a/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs +++ b/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs @@ -1,9 +1,10 @@ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace ARMeilleure.Signal { - class WindowsSignalHandlerRegistration + unsafe class WindowsSignalHandlerRegistration { [DllImport("kernel32.dll")] private static extern IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler); @@ -11,6 +12,14 @@ namespace ARMeilleure.Signal [DllImport("kernel32.dll")] private static extern ulong RemoveVectoredExceptionHandler(IntPtr handle); + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Ansi)] + static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName); + + [DllImport("kernel32.dll", CharSet = CharSet.Ansi, ExactSpelling = true, SetLastError = true)] + private static extern IntPtr GetProcAddress(IntPtr hModule, string procName); + + private static IntPtr _getCurrentThreadIdPtr; + public static IntPtr RegisterExceptionHandler(IntPtr action) { return AddVectoredExceptionHandler(1, action); @@ -20,5 +29,17 @@ namespace ARMeilleure.Signal { return RemoveVectoredExceptionHandler(handle) != 0; } + + public static IntPtr GetCurrentThreadIdFunc() + { + if (_getCurrentThreadIdPtr == IntPtr.Zero) + { + IntPtr handle = LoadLibrary("kernel32.dll"); + + _getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId"); + } + + return _getCurrentThreadIdPtr; + } } } |
