aboutsummaryrefslogtreecommitdiff
path: root/src/ARMeilleure/Signal
diff options
context:
space:
mode:
authorTSR Berry <20988865+TSRBerry@users.noreply.github.com>2023-04-08 01:22:00 +0200
committerMary <thog@protonmail.com>2023-04-27 23:51:14 +0200
commitcee712105850ac3385cd0091a923438167433f9f (patch)
tree4a5274b21d8b7f938c0d0ce18736d3f2993b11b1 /src/ARMeilleure/Signal
parentcd124bda587ef09668a971fa1cac1c3f0cfc9f21 (diff)
Move solution and projects to src
Diffstat (limited to 'src/ARMeilleure/Signal')
-rw-r--r--src/ARMeilleure/Signal/NativeSignalHandler.cs422
-rw-r--r--src/ARMeilleure/Signal/TestMethods.cs84
-rw-r--r--src/ARMeilleure/Signal/UnixSignalHandlerRegistration.cs83
-rw-r--r--src/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs186
-rw-r--r--src/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs44
5 files changed, 819 insertions, 0 deletions
diff --git a/src/ARMeilleure/Signal/NativeSignalHandler.cs b/src/ARMeilleure/Signal/NativeSignalHandler.cs
new file mode 100644
index 00000000..cddeb817
--- /dev/null
+++ b/src/ARMeilleure/Signal/NativeSignalHandler.cs
@@ -0,0 +1,422 @@
+using ARMeilleure.IntermediateRepresentation;
+using ARMeilleure.Memory;
+using ARMeilleure.Translation;
+using ARMeilleure.Translation.Cache;
+using System;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+
+using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
+
+namespace ARMeilleure.Signal
+{
+ [StructLayout(LayoutKind.Sequential, Pack = 1)]
+ struct SignalHandlerRange
+ {
+ public int IsActive;
+ public nuint RangeAddress;
+ public nuint RangeEndAddress;
+ public IntPtr ActionPointer;
+ }
+
+ [StructLayout(LayoutKind.Sequential, Pack = 1)]
+ struct SignalHandlerConfig
+ {
+ /// <summary>
+ /// The byte offset of the faulting address in the SigInfo or ExceptionRecord struct.
+ /// </summary>
+ public int StructAddressOffset;
+
+ /// <summary>
+ /// The byte offset of the write flag in the SigInfo or ExceptionRecord struct.
+ /// </summary>
+ public int StructWriteOffset;
+
+ /// <summary>
+ /// The sigaction handler that was registered before this one. (unix only)
+ /// </summary>
+ public nuint UnixOldSigaction;
+
+ /// <summary>
+ /// The type of the previous sigaction. True for the 3 argument variant. (unix only)
+ /// </summary>
+ public int UnixOldSigaction3Arg;
+
+ public SignalHandlerRange Range0;
+ public SignalHandlerRange Range1;
+ public SignalHandlerRange Range2;
+ public SignalHandlerRange Range3;
+ public SignalHandlerRange Range4;
+ public SignalHandlerRange Range5;
+ public SignalHandlerRange Range6;
+ public SignalHandlerRange Range7;
+ }
+
+ public static class NativeSignalHandler
+ {
+ private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext);
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ private delegate int VectoredExceptionHandler(IntPtr exceptionInfo);
+
+ private const int MaxTrackedRanges = 8;
+
+ private const int StructAddressOffset = 0;
+ private const int StructWriteOffset = 4;
+ private const int UnixOldSigaction = 8;
+ private const int UnixOldSigaction3Arg = 16;
+ private const int RangeOffset = 20;
+
+ private const int EXCEPTION_CONTINUE_SEARCH = 0;
+ private const int EXCEPTION_CONTINUE_EXECUTION = -1;
+
+ private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;
+
+ private static ulong _pageSize;
+ private static ulong _pageMask;
+
+ private static IntPtr _handlerConfig;
+ private static IntPtr _signalHandlerPtr;
+ private static IntPtr _signalHandlerHandle;
+
+ private static readonly object _lock = new object();
+ private static bool _initialized;
+
+ static NativeSignalHandler()
+ {
+ _handlerConfig = Marshal.AllocHGlobal(Unsafe.SizeOf<SignalHandlerConfig>());
+ ref SignalHandlerConfig config = ref GetConfigRef();
+
+ config = new SignalHandlerConfig();
+ }
+
+ public static void Initialize(IJitMemoryAllocator allocator)
+ {
+ JitCache.Initialize(allocator);
+ }
+
+ public static void InitializeSignalHandler(ulong pageSize, Func<IntPtr, IntPtr, IntPtr> customSignalHandlerFactory = null)
+ {
+ if (_initialized) return;
+
+ lock (_lock)
+ {
+ if (_initialized) return;
+
+ _pageSize = pageSize;
+ _pageMask = pageSize - 1;
+
+ ref SignalHandlerConfig config = ref GetConfigRef();
+
+ if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
+ {
+ _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateUnixSignalHandler(_handlerConfig));
+
+ if (customSignalHandlerFactory != null)
+ {
+ _signalHandlerPtr = customSignalHandlerFactory(UnixSignalHandlerRegistration.GetSegfaultExceptionHandler().sa_handler, _signalHandlerPtr);
+ }
+
+ var old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
+
+ config.UnixOldSigaction = (nuint)(ulong)old.sa_handler;
+ config.UnixOldSigaction3Arg = old.sa_flags & 4;
+ }
+ else
+ {
+ config.StructAddressOffset = 40; // ExceptionInformation1
+ config.StructWriteOffset = 32; // ExceptionInformation0
+
+ _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateWindowsSignalHandler(_handlerConfig));
+
+ if (customSignalHandlerFactory != null)
+ {
+ _signalHandlerPtr = customSignalHandlerFactory(IntPtr.Zero, _signalHandlerPtr);
+ }
+
+ _signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
+ }
+
+ _initialized = true;
+ }
+ }
+
+ private static unsafe ref SignalHandlerConfig GetConfigRef()
+ {
+ return ref Unsafe.AsRef<SignalHandlerConfig>((void*)_handlerConfig);
+ }
+
+ public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action)
+ {
+ var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;
+
+ for (int i = 0; i < MaxTrackedRanges; i++)
+ {
+ if (ranges[i].IsActive == 0)
+ {
+ ranges[i].RangeAddress = address;
+ ranges[i].RangeEndAddress = endAddress;
+ ranges[i].ActionPointer = action;
+ ranges[i].IsActive = 1;
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ public static unsafe bool RemoveTrackedRegion(nuint address)
+ {
+ var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;
+
+ for (int i = 0; i < MaxTrackedRanges; i++)
+ {
+ if (ranges[i].IsActive == 1 && ranges[i].RangeAddress == address)
+ {
+ ranges[i].IsActive = 0;
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite)
+ {
+ Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
+ context.Copy(inRegionLocal, Const(0));
+
+ Operand endLabel = Label();
+
+ for (int i = 0; i < MaxTrackedRanges; i++)
+ {
+ ulong rangeBaseOffset = (ulong)(RangeOffset + i * Unsafe.SizeOf<SignalHandlerRange>());
+
+ Operand nextLabel = Label();
+
+ Operand isActive = context.Load(OperandType.I32, Const((ulong)signalStructPtr + rangeBaseOffset));
+
+ context.BranchIfFalse(nextLabel, isActive);
+
+ Operand rangeAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 4));
+ Operand rangeEndAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 12));
+
+ // Is the fault address within this tracked region?
+ Operand inRange = context.BitwiseAnd(
+ context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
+ context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI)
+ );
+
+ // Only call tracking if in range.
+ context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);
+
+ Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~_pageMask));
+
+ // Call the tracking action, with the pointer's relative offset to the base address.
+ Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
+
+ context.Copy(inRegionLocal, Const(0));
+
+ Operand skipActionLabel = Label();
+
+ // Tracking action should be non-null to call it, otherwise assume false return.
+ context.BranchIfFalse(skipActionLabel, trackingActionPtr);
+ Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(_pageSize), isWrite);
+ context.Copy(inRegionLocal, result);
+
+ context.MarkLabel(skipActionLabel);
+
+ // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
+ if (OperatingSystem.IsWindows())
+ {
+ context.BranchIfTrue(endLabel, inRegionLocal);
+
+ context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
+ }
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(nextLabel);
+ }
+
+ context.MarkLabel(endLabel);
+
+ return context.Copy(inRegionLocal);
+ }
+
+ private static Operand GenerateUnixFaultAddress(EmitterContext context, Operand sigInfoPtr)
+ {
+ ulong structAddressOffset = OperatingSystem.IsMacOS() ? 24ul : 16ul; // si_addr
+ return context.Load(OperandType.I64, context.Add(sigInfoPtr, Const(structAddressOffset)));
+ }
+
+ private static Operand GenerateUnixWriteFlag(EmitterContext context, Operand ucontextPtr)
+ {
+ if (OperatingSystem.IsMacOS())
+ {
+ const ulong mcontextOffset = 48; // uc_mcontext
+ Operand ctxPtr = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(mcontextOffset)));
+
+ if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
+ {
+ const ulong esrOffset = 8; // __es.__esr
+ Operand esr = context.Load(OperandType.I64, context.Add(ctxPtr, Const(esrOffset)));
+ return context.BitwiseAnd(esr, Const(0x40ul));
+ }
+
+ if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
+ {
+ const ulong errOffset = 4; // __es.__err
+ Operand err = context.Load(OperandType.I64, context.Add(ctxPtr, Const(errOffset)));
+ return context.BitwiseAnd(err, Const(2ul));
+ }
+ }
+ else if (OperatingSystem.IsLinux())
+ {
+ if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
+ {
+ Operand auxPtr = context.AllocateLocal(OperandType.I64);
+
+ Operand loopLabel = Label();
+ Operand successLabel = Label();
+
+ const ulong auxOffset = 464; // uc_mcontext.__reserved
+ const uint esrMagic = 0x45535201;
+
+ context.Copy(auxPtr, context.Add(ucontextPtr, Const(auxOffset)));
+
+ context.MarkLabel(loopLabel);
+
+ // _aarch64_ctx::magic
+ Operand magic = context.Load(OperandType.I32, auxPtr);
+ // _aarch64_ctx::size
+ Operand size = context.Load(OperandType.I32, context.Add(auxPtr, Const(4ul)));
+
+ context.BranchIf(successLabel, magic, Const(esrMagic), Comparison.Equal);
+
+ context.Copy(auxPtr, context.Add(auxPtr, context.ZeroExtend32(OperandType.I64, size)));
+
+ context.Branch(loopLabel);
+
+ context.MarkLabel(successLabel);
+
+ // esr_context::esr
+ Operand esr = context.Load(OperandType.I64, context.Add(auxPtr, Const(8ul)));
+ return context.BitwiseAnd(esr, Const(0x40ul));
+ }
+
+ if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
+ {
+ const int errOffset = 192; // uc_mcontext.gregs[REG_ERR]
+ Operand err = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(errOffset)));
+ return context.BitwiseAnd(err, Const(2ul));
+ }
+ }
+
+ throw new PlatformNotSupportedException();
+ }
+
+ private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr)
+ {
+ EmitterContext context = new EmitterContext();
+
+ // (int sig, SigInfo* sigInfo, void* ucontext)
+ Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);
+ Operand ucontextPtr = context.LoadArgument(OperandType.I64, 2);
+
+ Operand faultAddress = GenerateUnixFaultAddress(context, sigInfoPtr);
+ Operand writeFlag = GenerateUnixWriteFlag(context, ucontextPtr);
+
+ Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
+
+ Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);
+
+ Operand endLabel = Label();
+
+ context.BranchIfTrue(endLabel, isInRegion);
+
+ Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));
+ Operand unixOldSigaction3Arg = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
+ Operand threeArgLabel = Label();
+
+ context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);
+
+ context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
+ context.Branch(endLabel);
+
+ context.MarkLabel(threeArgLabel);
+
+ context.Call(unixOldSigaction,
+ OperandType.None,
+ context.LoadArgument(OperandType.I32, 0),
+ sigInfoPtr,
+ context.LoadArgument(OperandType.I64, 2)
+ );
+
+ context.MarkLabel(endLabel);
+
+ context.Return();
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<UnixExceptionHandler>();
+ }
+
+ private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr)
+ {
+ EmitterContext context = new EmitterContext();
+
+ // (ExceptionPointers* exceptionInfo)
+ Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
+ Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);
+
+ // First thing's first - this catches a number of exceptions, but we only want access violations.
+ Operand validExceptionLabel = Label();
+
+ Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);
+
+ context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);
+
+ context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.
+
+ context.MarkLabel(validExceptionLabel);
+
+ // Next, read the address of the invalid access, and whether it is a write or not.
+
+ Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
+ Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));
+
+ Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
+ Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));
+
+ Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
+
+ Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);
+
+ Operand endLabel = Label();
+
+ // If the region check result is false, then run the next vectored exception handler.
+
+ context.BranchIfTrue(endLabel, isInRegion);
+
+ context.Return(Const(EXCEPTION_CONTINUE_SEARCH));
+
+ context.MarkLabel(endLabel);
+
+ // Otherwise, return to execution.
+
+ context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<VectoredExceptionHandler>();
+ }
+ }
+}
diff --git a/src/ARMeilleure/Signal/TestMethods.cs b/src/ARMeilleure/Signal/TestMethods.cs
new file mode 100644
index 00000000..e2ecad24
--- /dev/null
+++ b/src/ARMeilleure/Signal/TestMethods.cs
@@ -0,0 +1,84 @@
+using ARMeilleure.IntermediateRepresentation;
+using ARMeilleure.Translation;
+using System;
+using System.Runtime.InteropServices;
+using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
+
+namespace ARMeilleure.Signal
+{
+ public struct NativeWriteLoopState
+ {
+ public int Running;
+ public int Error;
+ }
+
+ public static class TestMethods
+ {
+ public delegate bool DebugPartialUnmap();
+ public delegate int DebugThreadLocalMapGetOrReserve(int threadId, int initialState);
+ public delegate void DebugNativeWriteLoop(IntPtr nativeWriteLoopPtr, IntPtr writePtr);
+
+ public static DebugPartialUnmap GenerateDebugPartialUnmap()
+ {
+ EmitterContext context = new EmitterContext();
+
+ var result = WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context);
+
+ context.Return(result);
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<DebugPartialUnmap>();
+ }
+
+ public static DebugThreadLocalMapGetOrReserve GenerateDebugThreadLocalMapGetOrReserve(IntPtr structPtr)
+ {
+ EmitterContext context = new EmitterContext();
+
+ var result = WindowsPartialUnmapHandler.EmitThreadLocalMapIntGetOrReserve(context, structPtr, context.LoadArgument(OperandType.I32, 0), context.LoadArgument(OperandType.I32, 1));
+
+ context.Return(result);
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<DebugThreadLocalMapGetOrReserve>();
+ }
+
+ public static DebugNativeWriteLoop GenerateDebugNativeWriteLoop()
+ {
+ EmitterContext context = new EmitterContext();
+
+ // Loop a write to the target address until "running" is false.
+
+ Operand structPtr = context.Copy(context.LoadArgument(OperandType.I64, 0));
+ Operand writePtr = context.Copy(context.LoadArgument(OperandType.I64, 1));
+
+ Operand loopLabel = Label();
+ context.MarkLabel(loopLabel);
+
+ context.Store(writePtr, Const(12345));
+
+ Operand running = context.Load(OperandType.I32, structPtr);
+
+ context.BranchIfTrue(loopLabel, running);
+
+ context.Return();
+
+ // Compile and return the function.
+
+ ControlFlowGraph cfg = context.GetControlFlowGraph();
+
+ OperandType[] argTypes = new OperandType[] { OperandType.I64 };
+
+ return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<DebugNativeWriteLoop>();
+ }
+ }
+}
diff --git a/src/ARMeilleure/Signal/UnixSignalHandlerRegistration.cs b/src/ARMeilleure/Signal/UnixSignalHandlerRegistration.cs
new file mode 100644
index 00000000..22009240
--- /dev/null
+++ b/src/ARMeilleure/Signal/UnixSignalHandlerRegistration.cs
@@ -0,0 +1,83 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace ARMeilleure.Signal
+{
+ static partial class UnixSignalHandlerRegistration
+ {
+ [StructLayout(LayoutKind.Sequential, Pack = 1)]
+ public unsafe struct SigSet
+ {
+ fixed long sa_mask[16];
+ }
+
+ [StructLayout(LayoutKind.Sequential, Pack = 1)]
+ public struct SigAction
+ {
+ public IntPtr sa_handler;
+ public SigSet sa_mask;
+ public int sa_flags;
+ public IntPtr sa_restorer;
+ }
+
+ private const int SIGSEGV = 11;
+ private const int SIGBUS = 10;
+ private const int SA_SIGINFO = 0x00000004;
+
+ [LibraryImport("libc", SetLastError = true)]
+ private static partial int sigaction(int signum, ref SigAction sigAction, out SigAction oldAction);
+
+ [LibraryImport("libc", SetLastError = true)]
+ private static partial int sigaction(int signum, IntPtr sigAction, out SigAction oldAction);
+
+ [LibraryImport("libc", SetLastError = true)]
+ private static partial int sigemptyset(ref SigSet set);
+
+ public static SigAction GetSegfaultExceptionHandler()
+ {
+ int result = sigaction(SIGSEGV, IntPtr.Zero, out SigAction old);
+
+ if (result != 0)
+ {
+ throw new InvalidOperationException($"Could not get SIGSEGV sigaction. Error: {result}");
+ }
+
+ return old;
+ }
+
+ public static SigAction RegisterExceptionHandler(IntPtr action)
+ {
+ SigAction sig = new SigAction
+ {
+ sa_handler = action,
+ sa_flags = SA_SIGINFO
+ };
+
+ sigemptyset(ref sig.sa_mask);
+
+ int result = sigaction(SIGSEGV, ref sig, out SigAction old);
+
+ if (result != 0)
+ {
+ throw new InvalidOperationException($"Could not register SIGSEGV sigaction. Error: {result}");
+ }
+
+ if (OperatingSystem.IsMacOS())
+ {
+ result = sigaction(SIGBUS, ref sig, out _);
+
+ if (result != 0)
+ {
+ throw new InvalidOperationException($"Could not register SIGBUS sigaction. Error: {result}");
+ }
+ }
+
+ return old;
+ }
+
+ public static bool RestoreExceptionHandler(SigAction oldAction)
+ {
+ return sigaction(SIGSEGV, ref oldAction, out SigAction _) == 0 && (!OperatingSystem.IsMacOS() || sigaction(SIGBUS, ref oldAction, out SigAction _) == 0);
+ }
+ }
+}
diff --git a/src/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs b/src/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs
new file mode 100644
index 00000000..941e36e5
--- /dev/null
+++ b/src/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs
@@ -0,0 +1,186 @@
+using ARMeilleure.IntermediateRepresentation;
+using ARMeilleure.Translation;
+using Ryujinx.Common.Memory.PartialUnmaps;
+using System;
+
+using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
+
+namespace ARMeilleure.Signal
+{
+ /// <summary>
+ /// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods.
+ /// </summary>
+ internal static class WindowsPartialUnmapHandler
+ {
+ public static Operand EmitRetryFromAccessViolation(EmitterContext context)
+ {
+ IntPtr partialRemapStatePtr = PartialUnmapState.GlobalState;
+ IntPtr localCountsPtr = IntPtr.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset);
+
+ // Get the lock first.
+ EmitNativeReaderLockAcquire(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
+
+ IntPtr getCurrentThreadId = WindowsSignalHandlerRegistration.GetCurrentThreadIdFunc();
+ Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32);
+ Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0));
+
+ Operand endLabel = Label();
+ Operand retry = context.AllocateLocal(OperandType.I32);
+ Operand threadIndexValidLabel = Label();
+
+ context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1)));
+
+ context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated.
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(threadIndexValidLabel);
+
+ Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex);
+ Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr);
+ Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset)));
+
+ context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount));
+
+ Operand noRetryLabel = Label();
+
+ context.BranchIfFalse(noRetryLabel, retry);
+
+ // if (retry) {
+
+ context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount);
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(noRetryLabel);
+
+ // }
+
+ context.MarkLabel(endLabel);
+
+ // Finally, release the lock and return the retry value.
+ EmitNativeReaderLockRelease(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
+
+ return retry;
+ }
+
+ public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand initialState)
+ {
+ Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
+
+ Operand i = context.AllocateLocal(OperandType.I32);
+
+ context.Copy(i, Const(0));
+
+ // (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate)
+
+ Operand endLabel = Label();
+
+ Operand loopLabel = Label();
+ context.MarkLabel(loopLabel);
+
+ Operand offset = context.Multiply(i, Const(sizeof(int)));
+ Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
+
+ // Check that this slot has the thread ID.
+ Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId);
+
+ // If it was already the thread ID, then we just need to return i.
+ context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId));
+
+ context.Copy(i, context.Add(i, Const(1)));
+
+ context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
+
+ // (Loop 2) Try take a slot that is 0 with our Thread ID.
+
+ context.Copy(i, Const(0)); // Reset i.
+
+ Operand loop2Label = Label();
+ context.MarkLabel(loop2Label);
+
+ Operand offset2 = context.Multiply(i, Const(sizeof(int)));
+ Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2));
+
+ // Try and swap in the thread id on top of 0.
+ Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId);
+
+ Operand idNot0Label = Label();
+
+ // If it was 0, then we need to initialize the struct entry and return i.
+ context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0)));
+
+ Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
+ Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2));
+ context.Store(structPtr, initialState);
+
+ context.Branch(endLabel);
+
+ context.MarkLabel(idNot0Label);
+
+ context.Copy(i, context.Add(i, Const(1)));
+
+ context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
+
+ context.Copy(i, Const(-1)); // Could not place the thread in the list.
+
+ context.MarkLabel(endLabel);
+
+ return context.Copy(i);
+ }
+
+ private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, IntPtr threadLocalMapPtr, Operand index)
+ {
+ Operand offset = context.Multiply(index, Const(sizeof(int)));
+ Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
+
+ return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset));
+ }
+
+ private static void EmitThreadLocalMapIntRelease(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand index)
+ {
+ Operand offset = context.Multiply(index, Const(sizeof(int)));
+ Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
+ Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
+
+ context.CompareAndSwap(idPtr, threadId, Const(0));
+ }
+
+ private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive)
+ {
+ Operand loop = Label();
+ context.MarkLabel(loop);
+
+ Operand initial = context.Load(OperandType.I32, ptr);
+ Operand newValue = context.Add(initial, additive);
+
+ Operand replaced = context.CompareAndSwap(ptr, initial, newValue);
+
+ context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced));
+ }
+
+ private static void EmitNativeReaderLockAcquire(EmitterContext context, IntPtr nativeReaderLockPtr)
+ {
+ Operand writeLockPtr = Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset));
+
+ // Spin until we can acquire the write lock.
+ Operand spinLabel = Label();
+ context.MarkLabel(spinLabel);
+
+ // Old value must be 0 to continue (we gained the write lock)
+ context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1)));
+
+ // Increment reader count.
+ EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1));
+
+ // Release write lock.
+ context.CompareAndSwap(writeLockPtr, Const(1), Const(0));
+ }
+
+ private static void EmitNativeReaderLockRelease(EmitterContext context, IntPtr nativeReaderLockPtr)
+ {
+ // Decrement reader count.
+ EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1));
+ }
+ }
+}
diff --git a/src/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs b/src/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs
new file mode 100644
index 00000000..3219e015
--- /dev/null
+++ b/src/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs
@@ -0,0 +1,44 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace ARMeilleure.Signal
+{
+ unsafe partial class WindowsSignalHandlerRegistration
+ {
+ [LibraryImport("kernel32.dll")]
+ private static partial IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler);
+
+ [LibraryImport("kernel32.dll")]
+ private static partial ulong RemoveVectoredExceptionHandler(IntPtr handle);
+
+ [LibraryImport("kernel32.dll", SetLastError = true, EntryPoint = "LoadLibraryA")]
+ private static partial IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName);
+
+ [LibraryImport("kernel32.dll", SetLastError = true)]
+ private static partial IntPtr GetProcAddress(IntPtr hModule, [MarshalAs(UnmanagedType.LPStr)] string procName);
+
+ private static IntPtr _getCurrentThreadIdPtr;
+
+ public static IntPtr RegisterExceptionHandler(IntPtr action)
+ {
+ return AddVectoredExceptionHandler(1, action);
+ }
+
+ public static bool RemoveExceptionHandler(IntPtr handle)
+ {
+ return RemoveVectoredExceptionHandler(handle) != 0;
+ }
+
+ public static IntPtr GetCurrentThreadIdFunc()
+ {
+ if (_getCurrentThreadIdPtr == IntPtr.Zero)
+ {
+ IntPtr handle = LoadLibrary("kernel32.dll");
+
+ _getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId");
+ }
+
+ return _getCurrentThreadIdPtr;
+ }
+ }
+}