aboutsummaryrefslogtreecommitdiff
path: root/ARMeilleure/Signal
diff options
context:
space:
mode:
Diffstat (limited to 'ARMeilleure/Signal')
-rw-r--r--ARMeilleure/Signal/NativeSignalHandler.cs21
-rw-r--r--ARMeilleure/Signal/TestMethods.cs84
-rw-r--r--ARMeilleure/Signal/WindowsPartialUnmapHandler.cs186
-rw-r--r--ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs23
4 files changed, 311 insertions, 3 deletions
diff --git a/ARMeilleure/Signal/NativeSignalHandler.cs b/ARMeilleure/Signal/NativeSignalHandler.cs
index cad0d420..0257f440 100644
--- a/ARMeilleure/Signal/NativeSignalHandler.cs
+++ b/ARMeilleure/Signal/NativeSignalHandler.cs
@@ -197,12 +197,29 @@ namespace ARMeilleure.Signal
// Only call tracking if in range.
context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);
- context.Copy(inRegionLocal, Const(1));
Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask));
// Call the tracking action, with the pointer's relative offset to the base address.
Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
- context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
+
+ context.Copy(inRegionLocal, Const(0));
+
+ Operand skipActionLabel = Label();
+
+ // Tracking action should be non-null to call it, otherwise assume false return.
+ context.BranchIfFalse(skipActionLabel, trackingActionPtr);
+ Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
+ context.Copy(inRegionLocal, result);
+
+ context.MarkLabel(skipActionLabel);
+
+ // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
+ if (OperatingSystem.IsWindows())
+ {
+ context.BranchIfTrue(endLabel, inRegionLocal);
+
+ context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
+ }
context.Branch(endLabel);
diff --git a/ARMeilleure/Signal/TestMethods.cs b/ARMeilleure/Signal/TestMethods.cs
new file mode 100644
index 00000000..2d7cef16
--- /dev/null
+++ b/ARMeilleure/Signal/TestMethods.cs
@@ -0,0 +1,84 @@
+using ARMeilleure.IntermediateRepresentation;
+using ARMeilleure.Translation;
+using System;
+
+using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
+
+namespace ARMeilleure.Signal
+{
+ public struct NativeWriteLoopState
+ {
+ public int Running;
+ public int Error;
+ }
+
+ public static class TestMethods
+ {
+ public delegate bool DebugPartialUnmap();
+ public delegate int DebugThreadLocalMapGetOrReserve(int threadId, int initialState);
+ public delegate void DebugNativeWriteLoop(IntPtr nativeWriteLoopPtr, IntPtr writePtr);
+
+ public static DebugPartialUnmap GenerateDebugPartialUnmap()
+ {
+ EmitterContext context = new EmitterContext();
+
+ var result = WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context);
+
+ context.Return(result);
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugPartialUnmap>();
+ }
+
+ public static DebugThreadLocalMapGetOrReserve GenerateDebugThreadLocalMapGetOrReserve(IntPtr structPtr)
+ {
+ EmitterContext context = new EmitterContext();
+
+ var result = WindowsPartialUnmapHandler.EmitThreadLocalMapIntGetOrReserve(context, structPtr, context.LoadArgument(OperandType.I32, 0), context.LoadArgument(OperandType.I32, 1));
+
+ context.Return(result);
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugThreadLocalMapGetOrReserve>();
+ }
+
+ public static DebugNativeWriteLoop GenerateDebugNativeWriteLoop()
+ {
+ EmitterContext context = new EmitterContext();
+
+ // Loop a write to the target address until "running" is false.
+
+ Operand structPtr = context.Copy(context.LoadArgument(OperandType.I64, 0));
+ Operand writePtr = context.Copy(context.LoadArgument(OperandType.I64, 1));
+
+ Operand loopLabel = Label();
+ context.MarkLabel(loopLabel);
+
+ context.Store(writePtr, Const(12345));
+
+ Operand running = context.Load(OperandType.I32, structPtr);
+
+ context.BranchIfTrue(loopLabel, running);
+
+ context.Return();
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<DebugNativeWriteLoop>();
+ }
+ }
+}
diff --git a/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs b/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs
new file mode 100644
index 00000000..941e36e5
--- /dev/null
+++ b/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs
@@ -0,0 +1,186 @@
+using ARMeilleure.IntermediateRepresentation;
+using ARMeilleure.Translation;
+using Ryujinx.Common.Memory.PartialUnmaps;
+using System;
+
+using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
+
+namespace ARMeilleure.Signal
+{
+ /// <summary>
+ /// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods.
+ /// </summary>
+ internal static class WindowsPartialUnmapHandler
+ {
+ public static Operand EmitRetryFromAccessViolation(EmitterContext context)
+ {
+ IntPtr partialRemapStatePtr = PartialUnmapState.GlobalState;
+ IntPtr localCountsPtr = IntPtr.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset);
+
+ // Get the lock first.
+ EmitNativeReaderLockAcquire(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
+
+ IntPtr getCurrentThreadId = WindowsSignalHandlerRegistration.GetCurrentThreadIdFunc();
+ Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32);
+ Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0));
+
+ Operand endLabel = Label();
+ Operand retry = context.AllocateLocal(OperandType.I32);
+ Operand threadIndexValidLabel = Label();
+
+ context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1)));
+
+ context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated.
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(threadIndexValidLabel);
+
+ Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex);
+ Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr);
+ Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset)));
+
+ context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount));
+
+ Operand noRetryLabel = Label();
+
+ context.BranchIfFalse(noRetryLabel, retry);
+
+ // if (retry) {
+
+ context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount);
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(noRetryLabel);
+
+ // }
+
+ context.MarkLabel(endLabel);
+
+ // Finally, release the lock and return the retry value.
+ EmitNativeReaderLockRelease(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
+
+ return retry;
+ }
+
+ public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand initialState)
+ {
+ Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
+
+ Operand i = context.AllocateLocal(OperandType.I32);
+
+ context.Copy(i, Const(0));
+
+ // (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate)
+
+ Operand endLabel = Label();
+
+ Operand loopLabel = Label();
+ context.MarkLabel(loopLabel);
+
+ Operand offset = context.Multiply(i, Const(sizeof(int)));
+ Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
+
+ // Check that this slot has the thread ID.
+ Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId);
+
+ // If it was already the thread ID, then we just need to return i.
+ context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId));
+
+ context.Copy(i, context.Add(i, Const(1)));
+
+ context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
+
+ // (Loop 2) Try take a slot that is 0 with our Thread ID.
+
+ context.Copy(i, Const(0)); // Reset i.
+
+ Operand loop2Label = Label();
+ context.MarkLabel(loop2Label);
+
+ Operand offset2 = context.Multiply(i, Const(sizeof(int)));
+ Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2));
+
+ // Try and swap in the thread id on top of 0.
+ Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId);
+
+ Operand idNot0Label = Label();
+
+ // If it was 0, then we need to initialize the struct entry and return i.
+ context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0)));
+
+ Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
+ Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2));
+ context.Store(structPtr, initialState);
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(idNot0Label);
+
+ context.Copy(i, context.Add(i, Const(1)));
+
+ context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
+
+ context.Copy(i, Const(-1)); // Could not place the thread in the list.
+
+ context.MarkLabel(endLabel);
+
+ return context.Copy(i);
+ }
+
+ private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, IntPtr threadLocalMapPtr, Operand index)
+ {
+ Operand offset = context.Multiply(index, Const(sizeof(int)));
+ Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
+
+ return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset));
+ }
+
+ private static void EmitThreadLocalMapIntRelease(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand index)
+ {
+ Operand offset = context.Multiply(index, Const(sizeof(int)));
+ Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
+ Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
+
+ context.CompareAndSwap(idPtr, threadId, Const(0));
+ }
+
+ private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive)
+ {
+ Operand loop = Label();
+ context.MarkLabel(loop);
+
+ Operand initial = context.Load(OperandType.I32, ptr);
+ Operand newValue = context.Add(initial, additive);
+
+ Operand replaced = context.CompareAndSwap(ptr, initial, newValue);
+
+ context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced));
+ }
+
+ private static void EmitNativeReaderLockAcquire(EmitterContext context, IntPtr nativeReaderLockPtr)
+ {
+ Operand writeLockPtr = Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset));
+
+ // Spin until we can acquire the write lock.
+ Operand spinLabel = Label();
+ context.MarkLabel(spinLabel);
+
+ // Old value must be 0 to continue (we gained the write lock)
+ context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1)));
+
+ // Increment reader count.
+ EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1));
+
+ // Release write lock.
+ context.CompareAndSwap(writeLockPtr, Const(1), Const(0));
+ }
+
+ private static void EmitNativeReaderLockRelease(EmitterContext context, IntPtr nativeReaderLockPtr)
+ {
+ // Decrement reader count.
+ EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1));
+ }
+ }
+}
diff --git a/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs b/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs
index 959d1c47..513829a6 100644
--- a/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs
+++ b/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs
@@ -1,9 +1,10 @@
using System;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
namespace ARMeilleure.Signal
{
- class WindowsSignalHandlerRegistration
+ unsafe class WindowsSignalHandlerRegistration
{
[DllImport("kernel32.dll")]
private static extern IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler);
@@ -11,6 +12,14 @@ namespace ARMeilleure.Signal
[DllImport("kernel32.dll")]
private static extern ulong RemoveVectoredExceptionHandler(IntPtr handle);
+ [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Ansi)]
+ static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName);
+
+ [DllImport("kernel32.dll", CharSet = CharSet.Ansi, ExactSpelling = true, SetLastError = true)]
+ private static extern IntPtr GetProcAddress(IntPtr hModule, string procName);
+
+ private static IntPtr _getCurrentThreadIdPtr;
+
public static IntPtr RegisterExceptionHandler(IntPtr action)
{
return AddVectoredExceptionHandler(1, action);
@@ -20,5 +29,17 @@ namespace ARMeilleure.Signal
{
return RemoveVectoredExceptionHandler(handle) != 0;
}
+
+ public static IntPtr GetCurrentThreadIdFunc()
+ {
+ if (_getCurrentThreadIdPtr == IntPtr.Zero)
+ {
+ IntPtr handle = LoadLibrary("kernel32.dll");
+
+ _getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId");
+ }
+
+ return _getCurrentThreadIdPtr;
+ }
}
}