From 6ed613a6e6a66d57d2fdb045d926e42dfcdd3206 Mon Sep 17 00:00:00 2001 From: gdkchan Date: Wed, 16 Aug 2023 21:31:07 -0300 Subject: Fix vote and shuffle shader instructions on AMD GPUs (#5540) * Move shuffle handling out of the backend to a transform pass * Handle subgroup sizes higher than 32 * Stop using the subgroup size control extension * Make GenerateShuffleFunction static * Shader cache version bump --- .../Translation/EmitterContextInsts.cs | 28 +++- .../Translation/FeatureFlags.cs | 1 + .../Translation/HelperFunctionManager.cs | 145 +++++++++++++++++++++ .../Translation/HelperFunctionName.cs | 6 + .../Translation/Transforms/ShufflePass.cs | 52 ++++++++ .../Translation/Transforms/TransformPasses.cs | 1 + 6 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs (limited to 'src/Ryujinx.Graphics.Shader/Translation') diff --git a/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs b/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs index 6cb57238..a08c8ea9 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs @@ -112,9 +112,13 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.AtomicXor, storageKind, Local(), Const(binding), e0, e1, value); } - public static Operand Ballot(this EmitterContext context, Operand a) + public static Operand Ballot(this EmitterContext context, Operand a, int index) { - return context.Add(Instruction.Ballot, Local(), a); + Operand dest = Local(); + + context.Add(new Operation(Instruction.Ballot, index, dest, a)); + + return dest; } public static Operand Barrier(this EmitterContext context) @@ -782,21 +786,41 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.ShiftRightU32, Local(), a, b); } + public static Operand Shuffle(this EmitterContext context, Operand a, Operand b) + { + return 
context.Add(Instruction.Shuffle, Local(), a, b); + } + public static (Operand, Operand) Shuffle(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.Shuffle, (Local(), Local()), a, b, c); } + public static Operand ShuffleDown(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.ShuffleDown, Local(), a, b); + } + public static (Operand, Operand) ShuffleDown(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ShuffleDown, (Local(), Local()), a, b, c); } + public static Operand ShuffleUp(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.ShuffleUp, Local(), a, b); + } + public static (Operand, Operand) ShuffleUp(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ShuffleUp, (Local(), Local()), a, b, c); } + public static Operand ShuffleXor(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.ShuffleXor, Local(), a, b); + } + public static (Operand, Operand) ShuffleXor(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ShuffleXor, (Local(), Local()), a, b, c); diff --git a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs index 5b7226ac..552a3f31 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs @@ -18,6 +18,7 @@ namespace Ryujinx.Graphics.Shader.Translation InstanceId = 1 << 3, DrawParameters = 1 << 4, RtLayer = 1 << 5, + Shuffle = 1 << 6, FixedFuncAttr = 1 << 9, LocalMemory = 1 << 10, SharedMemory = 1 << 11, diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs index 2addff5c..ef2f8759 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs 
+++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs @@ -56,6 +56,20 @@ namespace Ryujinx.Graphics.Shader.Translation return functionId; } + public int GetOrCreateShuffleFunctionId(HelperFunctionName functionName, int subgroupSize) + { + if (_functionIds.TryGetValue((int)functionName, out int functionId)) + { + return functionId; + } + + Function function = GenerateShuffleFunction(functionName, subgroupSize); + functionId = AddFunction(function); + _functionIds.Add((int)functionName, functionId); + + return functionId; + } + private Function GenerateFunction(HelperFunctionName functionName) { return functionName switch @@ -216,6 +230,137 @@ namespace Ryujinx.Graphics.Shader.Translation return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0); } + private static Function GenerateShuffleFunction(HelperFunctionName functionName, int subgroupSize) + { + return functionName switch + { + HelperFunctionName.Shuffle => GenerateShuffle(subgroupSize), + HelperFunctionName.ShuffleDown => GenerateShuffleDown(subgroupSize), + HelperFunctionName.ShuffleUp => GenerateShuffleUp(subgroupSize), + HelperFunctionName.ShuffleXor => GenerateShuffleXor(subgroupSize), + _ => throw new ArgumentException($"Invalid function name {functionName}"), + }; + } + + private static Function GenerateShuffle(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand clamp = context.BitwiseAnd(mask, Const(0x1f)); + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand minThreadId = context.BitwiseAnd(GenerateLoadSubgroupLaneId(context, subgroupSize), segMask); + Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId); + Operand srcThreadId = context.BitwiseOr(context.BitwiseAnd(index, context.BitwiseNot(segMask)), 
minThreadId); + Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "Shuffle", true, 3, 1); + } + + private static Function GenerateShuffleDown(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand clamp = context.BitwiseAnd(mask, Const(0x1f)); + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize); + Operand minThreadId = context.BitwiseAnd(laneId, segMask); + Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId); + Operand srcThreadId = context.IAdd(laneId, index); + Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleDown", true, 3, 1); + } + + private static Function GenerateShuffleUp(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize); + Operand minThreadId = context.BitwiseAnd(laneId, segMask); + Operand srcThreadId = context.ISubtract(laneId, index); + Operand valid = 
context.ICompareGreaterOrEqual(srcThreadId, minThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleUp", true, 3, 1); + } + + private static Function GenerateShuffleXor(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand clamp = context.BitwiseAnd(mask, Const(0x1f)); + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize); + Operand minThreadId = context.BitwiseAnd(laneId, segMask); + Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId); + Operand srcThreadId = context.BitwiseExclusiveOr(laneId, index); + Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleXor", true, 3, 1); + } + + private static Operand GenerateLoadSubgroupLaneId(EmitterContext context, int subgroupSize) + { + if (subgroupSize <= 32) + { + return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + } + + return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f)); + } + + private static Operand GenerateSubgroupShuffleIndex(EmitterContext context, Operand srcThreadId, int subgroupSize) + { + if (subgroupSize <= 32) + { + return srcThreadId; + } + + return context.BitwiseOr( + 
context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x60)), + srcThreadId); + } + private Function GenerateTexelFetchScaleFunction() { EmitterContext context = new(); diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs index e5af1735..09b17729 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs @@ -2,12 +2,18 @@ namespace Ryujinx.Graphics.Shader.Translation { enum HelperFunctionName { + Invalid, + ConvertDoubleToFloat, ConvertFloatToDouble, SharedAtomicMaxS32, SharedAtomicMinS32, SharedStore8, SharedStore16, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, TexelFetchScale, TextureSizeUnscale, } diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs new file mode 100644 index 00000000..839d4f81 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs @@ -0,0 +1,52 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation.Optimizations; +using System.Collections.Generic; +using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper; + +namespace Ryujinx.Graphics.Shader.Translation.Transforms +{ + class ShufflePass : ITransformPass + { + public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures) + { + return usedFeatures.HasFlag(FeatureFlags.Shuffle); + } + + public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node) + { + Operation operation = (Operation)node.Value; + + HelperFunctionName functionName = operation.Inst switch + { + Instruction.Shuffle => HelperFunctionName.Shuffle, + Instruction.ShuffleDown => HelperFunctionName.ShuffleDown, + Instruction.ShuffleUp => 
HelperFunctionName.ShuffleUp, + Instruction.ShuffleXor => HelperFunctionName.ShuffleXor, + _ => HelperFunctionName.Invalid, + }; + + if (functionName == HelperFunctionName.Invalid || operation.SourcesCount != 3 || operation.DestsCount != 2) + { + return node; + } + + int functionId = context.Hfm.GetOrCreateShuffleFunctionId(functionName, context.GpuAccessor.QueryHostSubgroupSize()); + + Operand result = operation.GetDest(0); + Operand valid = operation.GetDest(1); + Operand value = operation.GetSource(0); + Operand index = operation.GetSource(1); + Operand mask = operation.GetSource(2); + + operation.Dest = null; + + Operand[] callArgs = new Operand[] { Const(functionId), value, index, mask, valid }; + + LinkedListNode<INode> newNode = node.List.AddBefore(node, new Operation(Instruction.Call, 0, result, callArgs)); + + Utils.DeleteNode(node, operation); + + return newNode; + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs index c3bbe7dd..29393880 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs @@ -13,6 +13,7 @@ namespace Ryujinx.Graphics.Shader.Translation.Transforms RunPass(context); RunPass(context); RunPass(context); + RunPass<ShufflePass>(context); } private static void RunPass<T>(TransformContext context) where T : ITransformPass -- cgit v1.2.3