about summary refs log tree commit diff
path: root/src/Ryujinx.Graphics.Shader/Translation
diff options
context:
space:
mode:
Diffstat (limited to 'src/Ryujinx.Graphics.Shader/Translation')
-rw-r--r-- src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs 28
-rw-r--r-- src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs 1
-rw-r--r-- src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs 145
-rw-r--r-- src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs 6
-rw-r--r-- src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs 52
-rw-r--r-- src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs 1
6 files changed, 231 insertions, 2 deletions
diff --git a/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs b/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs
index 6cb57238..a08c8ea9 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs
@@ -112,9 +112,13 @@ namespace Ryujinx.Graphics.Shader.Translation
return context.Add(Instruction.AtomicXor, storageKind, Local(), Const(binding), e0, e1, value);
}
- public static Operand Ballot(this EmitterContext context, Operand a)
+ public static Operand Ballot(this EmitterContext context, Operand a, int index)
{
- return context.Add(Instruction.Ballot, Local(), a);
+ Operand dest = Local();
+
+ context.Add(new Operation(Instruction.Ballot, index, dest, a));
+
+ return dest;
}
public static Operand Barrier(this EmitterContext context)
@@ -782,21 +786,41 @@ namespace Ryujinx.Graphics.Shader.Translation
return context.Add(Instruction.ShiftRightU32, Local(), a, b);
}
+ public static Operand Shuffle(this EmitterContext context, Operand a, Operand b)
+ {
+ return context.Add(Instruction.Shuffle, Local(), a, b);
+ }
+
public static (Operand, Operand) Shuffle(this EmitterContext context, Operand a, Operand b, Operand c)
{
return context.Add(Instruction.Shuffle, (Local(), Local()), a, b, c);
}
+ public static Operand ShuffleDown(this EmitterContext context, Operand a, Operand b)
+ {
+ return context.Add(Instruction.ShuffleDown, Local(), a, b);
+ }
+
public static (Operand, Operand) ShuffleDown(this EmitterContext context, Operand a, Operand b, Operand c)
{
return context.Add(Instruction.ShuffleDown, (Local(), Local()), a, b, c);
}
+ public static Operand ShuffleUp(this EmitterContext context, Operand a, Operand b)
+ {
+ return context.Add(Instruction.ShuffleUp, Local(), a, b);
+ }
+
public static (Operand, Operand) ShuffleUp(this EmitterContext context, Operand a, Operand b, Operand c)
{
return context.Add(Instruction.ShuffleUp, (Local(), Local()), a, b, c);
}
+ public static Operand ShuffleXor(this EmitterContext context, Operand a, Operand b)
+ {
+ return context.Add(Instruction.ShuffleXor, Local(), a, b);
+ }
+
public static (Operand, Operand) ShuffleXor(this EmitterContext context, Operand a, Operand b, Operand c)
{
return context.Add(Instruction.ShuffleXor, (Local(), Local()), a, b, c);
diff --git a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs
index 5b7226ac..552a3f31 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs
@@ -18,6 +18,7 @@ namespace Ryujinx.Graphics.Shader.Translation
InstanceId = 1 << 3,
DrawParameters = 1 << 4,
RtLayer = 1 << 5,
+ Shuffle = 1 << 6,
FixedFuncAttr = 1 << 9,
LocalMemory = 1 << 10,
SharedMemory = 1 << 11,
diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs
index 2addff5c..ef2f8759 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs
@@ -56,6 +56,20 @@ namespace Ryujinx.Graphics.Shader.Translation
return functionId;
}
+ public int GetOrCreateShuffleFunctionId(HelperFunctionName functionName, int subgroupSize)
+ {
+ if (_functionIds.TryGetValue((int)functionName, out int functionId))
+ {
+ return functionId;
+ }
+
+ Function function = GenerateShuffleFunction(functionName, subgroupSize);
+ functionId = AddFunction(function);
+ _functionIds.Add((int)functionName, functionId);
+
+ return functionId;
+ }
+
private Function GenerateFunction(HelperFunctionName functionName)
{
return functionName switch
@@ -216,6 +230,137 @@ namespace Ryujinx.Graphics.Shader.Translation
return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0);
}
+ private static Function GenerateShuffleFunction(HelperFunctionName functionName, int subgroupSize)
+ {
+ return functionName switch
+ {
+ HelperFunctionName.Shuffle => GenerateShuffle(subgroupSize),
+ HelperFunctionName.ShuffleDown => GenerateShuffleDown(subgroupSize),
+ HelperFunctionName.ShuffleUp => GenerateShuffleUp(subgroupSize),
+ HelperFunctionName.ShuffleXor => GenerateShuffleXor(subgroupSize),
+ _ => throw new ArgumentException($"Invalid function name {functionName}"),
+ };
+ }
+
+ private static Function GenerateShuffle(int subgroupSize)
+ {
+ EmitterContext context = new();
+
+ Operand value = Argument(0);
+ Operand index = Argument(1);
+ Operand mask = Argument(2);
+
+ Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
+ Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+ Operand minThreadId = context.BitwiseAnd(GenerateLoadSubgroupLaneId(context, subgroupSize), segMask);
+ Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
+ Operand srcThreadId = context.BitwiseOr(context.BitwiseAnd(index, context.BitwiseNot(segMask)), minThreadId);
+ Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
+
+ context.Copy(Argument(3), valid);
+
+ Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+ context.Return(context.ConditionalSelect(valid, result, value));
+
+ return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "Shuffle", true, 3, 1);
+ }
+
+ private static Function GenerateShuffleDown(int subgroupSize)
+ {
+ EmitterContext context = new();
+
+ Operand value = Argument(0);
+ Operand index = Argument(1);
+ Operand mask = Argument(2);
+
+ Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
+ Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+ Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
+ Operand minThreadId = context.BitwiseAnd(laneId, segMask);
+ Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
+ Operand srcThreadId = context.IAdd(laneId, index);
+ Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
+
+ context.Copy(Argument(3), valid);
+
+ Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+ context.Return(context.ConditionalSelect(valid, result, value));
+
+ return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleDown", true, 3, 1);
+ }
+
+ private static Function GenerateShuffleUp(int subgroupSize)
+ {
+ EmitterContext context = new();
+
+ Operand value = Argument(0);
+ Operand index = Argument(1);
+ Operand mask = Argument(2);
+
+ Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+ Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
+ Operand minThreadId = context.BitwiseAnd(laneId, segMask);
+ Operand srcThreadId = context.ISubtract(laneId, index);
+ Operand valid = context.ICompareGreaterOrEqual(srcThreadId, minThreadId);
+
+ context.Copy(Argument(3), valid);
+
+ Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+ context.Return(context.ConditionalSelect(valid, result, value));
+
+ return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleUp", true, 3, 1);
+ }
+
+ private static Function GenerateShuffleXor(int subgroupSize)
+ {
+ EmitterContext context = new();
+
+ Operand value = Argument(0);
+ Operand index = Argument(1);
+ Operand mask = Argument(2);
+
+ Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
+ Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
+ Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
+ Operand minThreadId = context.BitwiseAnd(laneId, segMask);
+ Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
+ Operand srcThreadId = context.BitwiseExclusiveOr(laneId, index);
+ Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
+
+ context.Copy(Argument(3), valid);
+
+ Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
+
+ context.Return(context.ConditionalSelect(valid, result, value));
+
+ return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleXor", true, 3, 1);
+ }
+
+ private static Operand GenerateLoadSubgroupLaneId(EmitterContext context, int subgroupSize)
+ {
+ if (subgroupSize <= 32)
+ {
+ return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
+ }
+
+ return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f));
+ }
+
+ private static Operand GenerateSubgroupShuffleIndex(EmitterContext context, Operand srcThreadId, int subgroupSize)
+ {
+ if (subgroupSize <= 32)
+ {
+ return srcThreadId;
+ }
+
+ return context.BitwiseOr(
+ context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x60)),
+ srcThreadId);
+ }
+
private Function GenerateTexelFetchScaleFunction()
{
EmitterContext context = new();
diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs
index e5af1735..09b17729 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs
@@ -2,12 +2,18 @@ namespace Ryujinx.Graphics.Shader.Translation
{
enum HelperFunctionName
{
+ Invalid,
+
ConvertDoubleToFloat,
ConvertFloatToDouble,
SharedAtomicMaxS32,
SharedAtomicMinS32,
SharedStore8,
SharedStore16,
+ Shuffle,
+ ShuffleDown,
+ ShuffleUp,
+ ShuffleXor,
TexelFetchScale,
TextureSizeUnscale,
}
diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs
new file mode 100644
index 00000000..839d4f81
--- /dev/null
+++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs
@@ -0,0 +1,52 @@
+using Ryujinx.Graphics.Shader.IntermediateRepresentation;
+using Ryujinx.Graphics.Shader.Translation.Optimizations;
+using System.Collections.Generic;
+using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
+
+namespace Ryujinx.Graphics.Shader.Translation.Transforms
+{
+ class ShufflePass : ITransformPass
+ {
+ public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures)
+ {
+ return usedFeatures.HasFlag(FeatureFlags.Shuffle);
+ }
+
+ public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node)
+ {
+ Operation operation = (Operation)node.Value;
+
+ HelperFunctionName functionName = operation.Inst switch
+ {
+ Instruction.Shuffle => HelperFunctionName.Shuffle,
+ Instruction.ShuffleDown => HelperFunctionName.ShuffleDown,
+ Instruction.ShuffleUp => HelperFunctionName.ShuffleUp,
+ Instruction.ShuffleXor => HelperFunctionName.ShuffleXor,
+ _ => HelperFunctionName.Invalid,
+ };
+
+ if (functionName == HelperFunctionName.Invalid || operation.SourcesCount != 3 || operation.DestsCount != 2)
+ {
+ return node;
+ }
+
+ int functionId = context.Hfm.GetOrCreateShuffleFunctionId(functionName, context.GpuAccessor.QueryHostSubgroupSize());
+
+ Operand result = operation.GetDest(0);
+ Operand valid = operation.GetDest(1);
+ Operand value = operation.GetSource(0);
+ Operand index = operation.GetSource(1);
+ Operand mask = operation.GetSource(2);
+
+ operation.Dest = null;
+
+ Operand[] callArgs = new Operand[] { Const(functionId), value, index, mask, valid };
+
+ LinkedListNode<INode> newNode = node.List.AddBefore(node, new Operation(Instruction.Call, 0, result, callArgs));
+
+ Utils.DeleteNode(node, operation);
+
+ return newNode;
+ }
+ }
+}
diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs
index c3bbe7dd..29393880 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs
@@ -13,6 +13,7 @@ namespace Ryujinx.Graphics.Shader.Translation.Transforms
RunPass<TexturePass>(context);
RunPass<SharedStoreSmallIntCas>(context);
RunPass<SharedAtomicSignedCas>(context);
+ RunPass<ShufflePass>(context);
}
private static void RunPass<T>(TransformContext context) where T : ITransformPass