aboutsummaryrefslogtreecommitdiff
path: root/ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs
diff options
context:
space:
mode:
authorFICTURE7 <FICTURE7@gmail.com>2021-10-09 01:15:44 +0400
committerGitHub <noreply@github.com>2021-10-08 18:15:44 -0300
commit69093cf2d69490862aff974f170cee63a0016fd0 (patch)
tree24507a2d3da862416d3c2d3ca228c89cb40d5437 /ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs
parentc54a14d0b8d445d9d0074861dca816cc801e4008 (diff)
Optimize LSRA (#2563)
* Optimize `TryAllocateRegWithtoutSpill` a bit * Add a fast path for when all registers are live. * Do not query `GetOverlapPosition` if the register is already in use (i.e: free position is 0). * Do not allocate child split list if not parent * Turn `LiveRange` into a reference struct `LiveRange` is now a reference wrapping struct like `Operand` and `Operation`. It has also been changed into a singly linked-list. In micro-benchmarks traversing the linked-list was faster than binary search on `List<T>`. Even for quite large input sizes (e.g: 1,000,000), surprisingly. Could be because the code gen for traversing the linked-list is much much cleaner and there is no virtual dispatch happening when checking if intervals overlaps. * Turn `LiveInterval` into an iterator The LSRA allocates in forward order and never inspect previous `LiveInterval` once they are expired. Something similar can be done for the `LiveRange`s within the `LiveInterval`s themselves. The `LiveInterval` is turned into a iterator which expires `LiveRange` within it. The iterator is moved forward along with interval walking code, i.e: AllocateInterval(context, interval, cIndex). * Remove `LinearScanAllocator.Sources` Local methods are less susceptible to do allocations than lambdas. * Optimize `GetOverlapPosition(interval)` a bit Time complexity should be in O(n+m) instead of O(nm) now. * Optimize `NumberLocals` a bit Use the same idea as in `HybridAllocator` to store the visited state in the MSB of the Operand's value instead of using a `HashSet<T>`. * Optimize `InsertSplitCopies` a bit Avoid allocating a redundant `CopyResolver`. * Optimize `InsertSplitCopiesAtEdges` a bit Avoid redundant allocations of `CopyResolver`. * Use stack allocation for `freePositions` Avoid redundant computations. * Add `UseList` Replace `SortedIntegerList` with an even more specialized data structure. It allocates memory on the arena allocators and does not require copying use positions when splitting it. * Turn `LiveInterval` into a reference struct `LiveInterval` is now a reference wrapping struct like `Operand` and `Operation`. The rationale behind turning this in a reference wrapping struct is because a `LiveInterval` is associated with each local variable, and these intervals may themselves be split further. I've seen translations having up to 8000 local variables. To make the `LiveInterval` unmanaged, a new data structure called `LiveIntervalList` was added to store child splits. This differs from `SortedList<,>` because it can contain intervals with the same start position. Really wished we got some more of C++ template in C#. :^( * Optimize `GetChildSplit` a bit No need to inspect the remaining ranges if we've reached a range which starts after position, since the split list is ordered. * Optimize `CopyResolver` a bit Lazily allocate the fill, spill and parallel copy structures since most of the time only one of them is needed. * Optimize `BitMap.Enumerator` a bit Marking `MoveNext` as `AggressiveInlining` allows RyuJIT to promote the `Enumerator` struct into registers completely, reducing load/store code a lot since it does not have to store the struct on the stack for ABI purposes. * Use stack allocation for `use/blockedPositions` * Optimize `AllocateWithSpill` a bit * Address feedback * Make `LiveInterval.AddRange(,)` more conservative Produces no diff against master, but just for good measure.
Diffstat (limited to 'ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs')
-rw-r--r--ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs382
1 files changed, 223 insertions, 159 deletions
diff --git a/ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs b/ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs
index fd1420a2..d8a40365 100644
--- a/ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs
+++ b/ARMeilleure/CodeGen/RegisterAllocators/LinearScanAllocator.cs
@@ -20,17 +20,13 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
private const int RegistersCount = 16;
private HashSet<int> _blockEdges;
-
private LiveRange[] _blockRanges;
-
private BitMap[] _blockLiveIn;
private List<LiveInterval> _intervals;
-
private LiveInterval[] _parentIntervals;
private List<(IntrusiveList<Operation>, Operation)> _operationNodes;
-
private int _operationsCount;
private class AllocationContext
@@ -45,6 +41,11 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
public int IntUsedRegisters { get; set; }
public int VecUsedRegisters { get; set; }
+ private readonly int[] _intFreePositions;
+ private readonly int[] _vecFreePositions;
+ private readonly int _intFreePositionsCount;
+ private readonly int _vecFreePositionsCount;
+
public AllocationContext(StackAllocator stackAlloc, RegisterMasks masks, int intervalsCount)
{
StackAlloc = stackAlloc;
@@ -52,6 +53,43 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
Active = new BitMap(Allocators.Default, intervalsCount);
Inactive = new BitMap(Allocators.Default, intervalsCount);
+
+ PopulateFreePositions(RegisterType.Integer, out _intFreePositions, out _intFreePositionsCount);
+ PopulateFreePositions(RegisterType.Vector, out _vecFreePositions, out _vecFreePositionsCount);
+
+ void PopulateFreePositions(RegisterType type, out int[] positions, out int count)
+ {
+ positions = new int[RegistersCount];
+ count = BitOperations.PopCount((uint)masks.GetAvailableRegisters(type));
+
+ int mask = masks.GetAvailableRegisters(type);
+
+ for (int i = 0; i < positions.Length; i++)
+ {
+ if ((mask & (1 << i)) != 0)
+ {
+ positions[i] = int.MaxValue;
+ }
+ }
+ }
+ }
+
+ public void GetFreePositions(RegisterType type, in Span<int> positions, out int count)
+ {
+ if (type == RegisterType.Integer)
+ {
+ _intFreePositions.CopyTo(positions);
+
+ count = _intFreePositionsCount;
+ }
+ else
+ {
+ Debug.Assert(type == RegisterType.Vector);
+
+ _vecFreePositions.CopyTo(positions);
+
+ count = _vecFreePositionsCount;
+ }
}
public void MoveActiveToInactive(int bit)
@@ -132,6 +170,8 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{
LiveInterval interval = _intervals[iIndex];
+ interval.Forward(current.GetStart());
+
if (interval.GetEnd() < current.GetStart())
{
context.Active.Clear(iIndex);
@@ -147,6 +187,8 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{
LiveInterval interval = _intervals[iIndex];
+ interval.Forward(current.GetStart());
+
if (interval.GetEnd() < current.GetStart())
{
context.Inactive.Clear(iIndex);
@@ -167,45 +209,48 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{
RegisterType regType = current.Local.Type.ToRegisterType();
- int availableRegisters = context.Masks.GetAvailableRegisters(regType);
-
- int[] freePositions = new int[RegistersCount];
+ Span<int> freePositions = stackalloc int[RegistersCount];
- for (int index = 0; index < RegistersCount; index++)
- {
- if ((availableRegisters & (1 << index)) != 0)
- {
- freePositions[index] = int.MaxValue;
- }
- }
+ context.GetFreePositions(regType, freePositions, out int freePositionsCount);
foreach (int iIndex in context.Active)
{
LiveInterval interval = _intervals[iIndex];
+ Register reg = interval.Register;
- if (interval.Register.Type == regType)
+ if (reg.Type == regType)
{
- freePositions[interval.Register.Index] = 0;
+ freePositions[reg.Index] = 0;
+ freePositionsCount--;
}
}
+ // If all registers are already active, return early. No point in inspecting the inactive set to look for
+ // holes.
+ if (freePositionsCount == 0)
+ {
+ return false;
+ }
+
foreach (int iIndex in context.Inactive)
{
LiveInterval interval = _intervals[iIndex];
+ Register reg = interval.Register;
+
+ ref int freePosition = ref freePositions[reg.Index];
- if (interval.Register.Type == regType)
+ if (reg.Type == regType && freePosition != 0)
{
int overlapPosition = interval.GetOverlapPosition(current);
- if (overlapPosition != LiveInterval.NotFound && freePositions[interval.Register.Index] > overlapPosition)
+ if (overlapPosition != LiveInterval.NotFound && freePosition > overlapPosition)
{
- freePositions[interval.Register.Index] = overlapPosition;
+ freePosition = overlapPosition;
}
}
}
int selectedReg = GetHighestValueIndex(freePositions);
-
int selectedNextUse = freePositions[selectedReg];
// Intervals starts and ends at odd positions, unless they span an entire
@@ -227,8 +272,6 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
}
else if (selectedNextUse < current.GetEnd())
{
- Debug.Assert(selectedNextUse > current.GetStart(), "Trying to split interval at the start.");
-
LiveInterval splitChild = current.Split(selectedNextUse);
if (splitChild.UsesCount != 0)
@@ -263,90 +306,72 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{
RegisterType regType = current.Local.Type.ToRegisterType();
- int availableRegisters = context.Masks.GetAvailableRegisters(regType);
+ Span<int> usePositions = stackalloc int[RegistersCount];
+ Span<int> blockedPositions = stackalloc int[RegistersCount];
- int[] usePositions = new int[RegistersCount];
- int[] blockedPositions = new int[RegistersCount];
-
- for (int index = 0; index < RegistersCount; index++)
- {
- if ((availableRegisters & (1 << index)) != 0)
- {
- usePositions[index] = int.MaxValue;
-
- blockedPositions[index] = int.MaxValue;
- }
- }
-
- void SetUsePosition(int index, int position)
- {
- usePositions[index] = Math.Min(usePositions[index], position);
- }
-
- void SetBlockedPosition(int index, int position)
- {
- blockedPositions[index] = Math.Min(blockedPositions[index], position);
-
- SetUsePosition(index, position);
- }
+ context.GetFreePositions(regType, usePositions, out _);
+ context.GetFreePositions(regType, blockedPositions, out _);
foreach (int iIndex in context.Active)
{
LiveInterval interval = _intervals[iIndex];
+ Register reg = interval.Register;
- if (!interval.IsFixed && interval.Register.Type == regType)
+ if (reg.Type == regType)
{
- int nextUse = interval.NextUseAfter(current.GetStart());
+ ref int usePosition = ref usePositions[reg.Index];
+ ref int blockedPosition = ref blockedPositions[reg.Index];
- if (nextUse != -1)
+ if (interval.IsFixed)
{
- SetUsePosition(interval.Register.Index, nextUse);
+ usePosition = 0;
+ blockedPosition = 0;
}
- }
- }
-
- foreach (int iIndex in context.Inactive)
- {
- LiveInterval interval = _intervals[iIndex];
-
- if (!interval.IsFixed && interval.Register.Type == regType && interval.Overlaps(current))
- {
- int nextUse = interval.NextUseAfter(current.GetStart());
-
- if (nextUse != -1)
+ else
{
- SetUsePosition(interval.Register.Index, nextUse);
- }
- }
- }
-
- foreach (int iIndex in context.Active)
- {
- LiveInterval interval = _intervals[iIndex];
+ int nextUse = interval.NextUseAfter(current.GetStart());
- if (interval.IsFixed && interval.Register.Type == regType)
- {
- SetBlockedPosition(interval.Register.Index, 0);
+ if (nextUse != LiveInterval.NotFound && usePosition > nextUse)
+ {
+ usePosition = nextUse;
+ }
+ }
}
}
foreach (int iIndex in context.Inactive)
{
LiveInterval interval = _intervals[iIndex];
+ Register reg = interval.Register;
- if (interval.IsFixed && interval.Register.Type == regType)
+ if (reg.Type == regType)
{
- int overlapPosition = interval.GetOverlapPosition(current);
+ ref int usePosition = ref usePositions[reg.Index];
+ ref int blockedPosition = ref blockedPositions[reg.Index];
- if (overlapPosition != LiveInterval.NotFound)
+ if (interval.IsFixed)
+ {
+ int overlapPosition = interval.GetOverlapPosition(current);
+
+ if (overlapPosition != LiveInterval.NotFound)
+ {
+ blockedPosition = Math.Min(blockedPosition, overlapPosition);
+ usePosition = Math.Min(usePosition, overlapPosition);
+ }
+ }
+ else if (interval.Overlaps(current))
{
- SetBlockedPosition(interval.Register.Index, overlapPosition);
+ int nextUse = interval.NextUseAfter(current.GetStart());
+
+ if (nextUse != LiveInterval.NotFound && usePosition > nextUse)
+ {
+ usePosition = nextUse;
+ }
}
}
}
int selectedReg = GetHighestValueIndex(usePositions);
-
int currentFirstUse = current.FirstUse();
Debug.Assert(currentFirstUse >= 0, "Current interval has no uses.");
@@ -405,24 +430,24 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
}
}
- private static int GetHighestValueIndex(int[] array)
+ private static int GetHighestValueIndex(Span<int> span)
{
- int higuest = array[0];
+ int highest = span[0];
- if (higuest == int.MaxValue)
+ if (highest == int.MaxValue)
{
return 0;
}
int selected = 0;
- for (int index = 1; index < array.Length; index++)
+ for (int index = 1; index < span.Length; index++)
{
- int current = array[index];
+ int current = span[index];
- if (higuest < current)
+ if (highest < current)
{
- higuest = current;
+ highest = current;
selected = index;
if (current == int.MaxValue)
@@ -543,21 +568,21 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
CopyResolver GetCopyResolver(int position)
{
- CopyResolver copyResolver = new CopyResolver();
-
- if (copyResolvers.TryAdd(position, copyResolver))
+ if (!copyResolvers.TryGetValue(position, out CopyResolver copyResolver))
{
- return copyResolver;
+ copyResolver = new CopyResolver();
+
+ copyResolvers.Add(position, copyResolver);
}
- return copyResolvers[position];
+ return copyResolver;
}
foreach (LiveInterval interval in _intervals.Where(x => x.IsSplit))
{
LiveInterval previous = interval;
- foreach (LiveInterval splitChild in interval.SplitChilds())
+ foreach (LiveInterval splitChild in interval.SplitChildren())
{
int splitPosition = splitChild.GetStart();
@@ -607,6 +632,12 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
return block.Index >= blocksCount;
}
+ // Reset iterators to beginning because GetSplitChild depends on the state of the iterator.
+ foreach (LiveInterval interval in _intervals)
+ {
+ interval.Reset();
+ }
+
for (BasicBlock block = cfg.Blocks.First; block != null; block = block.ListNext)
{
if (IsSplitEdgeBlock(block))
@@ -629,7 +660,7 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
succIndex = successor.GetSuccessor(0).Index;
}
- CopyResolver copyResolver = new CopyResolver();
+ CopyResolver copyResolver = null;
foreach (int iIndex in _blockLiveIn[succIndex])
{
@@ -646,13 +677,18 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
LiveInterval left = interval.GetSplitChild(lEnd);
LiveInterval right = interval.GetSplitChild(rStart);
- if (left != null && right != null && left != right)
+ if (left != default && right != default && left != right)
{
+ if (copyResolver == null)
+ {
+ copyResolver = new CopyResolver();
+ }
+
copyResolver.AddSplit(left, right);
}
}
- if (!copyResolver.HasCopy)
+ if (copyResolver == null || !copyResolver.HasCopy)
{
continue;
}
@@ -699,10 +735,8 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{
Operand register = GetRegister(current);
- IList<int> usePositions = current.UsePositions();
- for (int i = usePositions.Count - 1; i >= 0; i--)
+ foreach (int usePosition in current.UsePositions())
{
- int usePosition = -usePositions[i];
(_, Operation operation) = GetOperationNode(usePosition);
for (int index = 0; index < operation.SourcesCount; index++)
@@ -759,7 +793,6 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
private void NumberLocals(ControlFlowGraph cfg)
{
_operationNodes = new List<(IntrusiveList<Operation>, Operation)>();
-
_intervals = new List<LiveInterval>();
for (int index = 0; index < RegistersCount; index++)
@@ -768,7 +801,18 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
_intervals.Add(new LiveInterval(new Register(index, RegisterType.Vector)));
}
- HashSet<Operand> visited = new HashSet<Operand>();
+ // The "visited" state is stored in the MSB of the local's value.
+ const ulong VisitedMask = 1ul << 63;
+
+ bool IsVisited(Operand local)
+ {
+ return (local.GetValueUnsafe() & VisitedMask) != 0;
+ }
+
+ void SetVisited(Operand local)
+ {
+ local.GetValueUnsafe() |= VisitedMask;
+ }
_operationsCount = 0;
@@ -784,11 +828,13 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
{
Operand dest = node.GetDestination(i);
- if (dest.Kind == OperandKind.LocalVariable && visited.Add(dest))
+ if (dest.Kind == OperandKind.LocalVariable && !IsVisited(dest))
{
dest.NumberLocal(_intervals.Count);
_intervals.Add(new LiveInterval(dest));
+
+ SetVisited(dest);
}
}
}
@@ -824,19 +870,45 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
for (Operation node = block.Operations.First; node != default; node = node.ListNext)
{
- Sources(node, (source) =>
+ for (int i = 0; i < node.SourcesCount; i++)
{
- int id = GetOperandId(source);
+ VisitSource(node.GetSource(i));
+ }
- if (!liveKill.IsSet(id))
+ for (int i = 0; i < node.DestinationsCount; i++)
+ {
+ VisitDestination(node.GetDestination(i));
+ }
+
+ void VisitSource(Operand source)
+ {
+ if (IsLocalOrRegister(source.Kind))
+ {
+ int id = GetOperandId(source);
+
+ if (!liveKill.IsSet(id))
+ {
+ liveGen.Set(id);
+ }
+ }
+ else if (source.Kind == OperandKind.Memory)
{
- liveGen.Set(id);
+ MemoryOperand memOp = source.GetMemory();
+
+ if (memOp.BaseAddress != default)
+ {
+ VisitSource(memOp.BaseAddress);
+ }
+
+ if (memOp.Index != default)
+ {
+ VisitSource(memOp.Index);
+ }
}
- });
+ }
- for (int i = 0; i < node.DestinationsCount; i++)
+ void VisitDestination(Operand dest)
{
- Operand dest = node.GetDestination(i);
liveKill.Set(GetOperandId(dest));
}
}
@@ -920,34 +992,65 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
continue;
}
- foreach (Operation node in BottomOperations(block))
+ for (Operation node = block.Operations.Last; node != default; node = node.ListPrevious)
{
operationPos -= InstructionGap;
for (int i = 0; i < node.DestinationsCount; i++)
{
- Operand dest = node.GetDestination(i);
- LiveInterval interval = _intervals[GetOperandId(dest)];
-
- interval.SetStart(operationPos + 1);
- interval.AddUsePosition(operationPos + 1);
+ VisitDestination(node.GetDestination(i));
}
- Sources(node, (source) =>
+ for (int i = 0; i < node.SourcesCount; i++)
{
- LiveInterval interval = _intervals[GetOperandId(source)];
-
- interval.AddRange(blockStart, operationPos + 1);
- interval.AddUsePosition(operationPos);
- });
+ VisitSource(node.GetSource(i));
+ }
if (node.Instruction == Instruction.Call)
{
AddIntervalCallerSavedReg(context.Masks.IntCallerSavedRegisters, operationPos, RegisterType.Integer);
AddIntervalCallerSavedReg(context.Masks.VecCallerSavedRegisters, operationPos, RegisterType.Vector);
}
+
+ void VisitSource(Operand source)
+ {
+ if (IsLocalOrRegister(source.Kind))
+ {
+ LiveInterval interval = _intervals[GetOperandId(source)];
+
+ interval.AddRange(blockStart, operationPos + 1);
+ interval.AddUsePosition(operationPos);
+ }
+ else if (source.Kind == OperandKind.Memory)
+ {
+ MemoryOperand memOp = source.GetMemory();
+
+ if (memOp.BaseAddress != default)
+ {
+ VisitSource(memOp.BaseAddress);
+ }
+
+ if (memOp.Index != default)
+ {
+ VisitSource(memOp.Index);
+ }
+ }
+ }
+
+ void VisitDestination(Operand dest)
+ {
+ LiveInterval interval = _intervals[GetOperandId(dest)];
+
+ interval.SetStart(operationPos + 1);
+ interval.AddUsePosition(operationPos + 1);
+ }
}
}
+
+ foreach (LiveInterval interval in _parentIntervals)
+ {
+ interval.Reset();
+ }
}
private void AddIntervalCallerSavedReg(int mask, int operationPos, RegisterType regType)
@@ -987,45 +1090,6 @@ namespace ARMeilleure.CodeGen.RegisterAllocators
return (register.Index << 1) | (register.Type == RegisterType.Vector ? 1 : 0);
}
- private static IEnumerable<Operation> BottomOperations(BasicBlock block)
- {
- Operation node = block.Operations.Last;
-
- while (node != default)
- {
- yield return node;
-
- node = node.ListPrevious;
- }
- }
-
- private static void Sources(Operation node, Action<Operand> action)
- {
- for (int index = 0; index < node.SourcesCount; index++)
- {
- Operand source = node.GetSource(index);
-
- if (IsLocalOrRegister(source.Kind))
- {
- action(source);
- }
- else if (source.Kind == OperandKind.Memory)
- {
- MemoryOperand memOp = source.GetMemory();
-
- if (memOp.BaseAddress != default)
- {
- action(memOp.BaseAddress);
- }
-
- if (memOp.Index != default)
- {
- action(memOp.Index);
- }
- }
- }
- }
-
private static bool IsLocalOrRegister(OperandKind kind)
{
return kind == OperandKind.LocalVariable ||