diff options
| author | gdkchan <gab.dark.100@gmail.com> | 2020-10-25 17:00:44 -0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-10-25 17:00:44 -0300 |
| commit | 49f970d5bd9163e2b4e26a33ef8f84529174d5de (patch) | |
| tree | eceaf9c0454d27413ca77689c06a24b47467d1a0 /Ryujinx.Graphics.Shader/Translation | |
| parent | 973a615d405a83d5fc2f6a11ad12ba63c2a76465 (diff) | |
Implement CAL and RET shader instructions (#1618)
* Add support for CAL and RET shader instructions
* Remove unused stuff
* Fix a bug that could cause the wrong values to be passed to a function
* Avoid repopulating function id dictionary every time
* PR feedback
* Fix vertex shader A/B merge
Diffstat (limited to 'Ryujinx.Graphics.Shader/Translation')
7 files changed, 735 insertions, 140 deletions
diff --git a/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs b/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs index e2ca74a4..fb0535c8 100644 --- a/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs +++ b/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs @@ -3,9 +3,52 @@ using System.Collections.Generic; namespace Ryujinx.Graphics.Shader.Translation { - static class ControlFlowGraph + class ControlFlowGraph { - public static BasicBlock[] MakeCfg(Operation[] operations) + public BasicBlock[] Blocks { get; } + public BasicBlock[] PostOrderBlocks { get; } + public int[] PostOrderMap { get; } + + public ControlFlowGraph(BasicBlock[] blocks) + { + Blocks = blocks; + + HashSet<BasicBlock> visited = new HashSet<BasicBlock>(); + + Stack<BasicBlock> blockStack = new Stack<BasicBlock>(); + + List<BasicBlock> postOrderBlocks = new List<BasicBlock>(blocks.Length); + + PostOrderMap = new int[blocks.Length]; + + visited.Add(blocks[0]); + + blockStack.Push(blocks[0]); + + while (blockStack.TryPop(out BasicBlock block)) + { + if (block.Next != null && visited.Add(block.Next)) + { + blockStack.Push(block); + blockStack.Push(block.Next); + } + else if (block.Branch != null && visited.Add(block.Branch)) + { + blockStack.Push(block); + blockStack.Push(block.Branch); + } + else + { + PostOrderMap[block.Index] = postOrderBlocks.Count; + + postOrderBlocks.Add(block); + } + } + + PostOrderBlocks = postOrderBlocks.ToArray(); + } + + public static ControlFlowGraph Create(Operation[] operations) { Dictionary<Operand, BasicBlock> labels = new Dictionary<Operand, BasicBlock>(); @@ -86,7 +129,7 @@ namespace Ryujinx.Graphics.Shader.Translation } } - return blocks.ToArray(); + return new ControlFlowGraph(blocks.ToArray()); } private static bool EndsWithUnconditionalInst(INode node) diff --git a/Ryujinx.Graphics.Shader/Translation/Dominance.cs b/Ryujinx.Graphics.Shader/Translation/Dominance.cs index 6a3ff35f..da4a38da 100644 --- a/Ryujinx.Graphics.Shader/Translation/Dominance.cs +++ b/Ryujinx.Graphics.Shader/Translation/Dominance.cs @@ -7,50 +7,18 @@ namespace Ryujinx.Graphics.Shader.Translation { // Those methods are an implementation of the algorithms on "A Simple, Fast Dominance Algorithm". // https://www.cs.rice.edu/~keith/EMBED/dom.pdf - public static void FindDominators(BasicBlock entry, int blocksCount) + public static void FindDominators(ControlFlowGraph cfg) { - HashSet<BasicBlock> visited = new HashSet<BasicBlock>(); - - Stack<BasicBlock> blockStack = new Stack<BasicBlock>(); - - List<BasicBlock> postOrderBlocks = new List<BasicBlock>(blocksCount); - - int[] postOrderMap = new int[blocksCount]; - - visited.Add(entry); - - blockStack.Push(entry); - - while (blockStack.TryPop(out BasicBlock block)) - { - if (block.Next != null && visited.Add(block.Next)) - { - blockStack.Push(block); - blockStack.Push(block.Next); - } - else if (block.Branch != null && visited.Add(block.Branch)) - { - blockStack.Push(block); - blockStack.Push(block.Branch); - } - else - { - postOrderMap[block.Index] = postOrderBlocks.Count; - - postOrderBlocks.Add(block); - } - } - BasicBlock Intersect(BasicBlock block1, BasicBlock block2) { while (block1 != block2) { - while (postOrderMap[block1.Index] < postOrderMap[block2.Index]) + while (cfg.PostOrderMap[block1.Index] < cfg.PostOrderMap[block2.Index]) { block1 = block1.ImmediateDominator; } - while (postOrderMap[block2.Index] < postOrderMap[block1.Index]) + while (cfg.PostOrderMap[block2.Index] < cfg.PostOrderMap[block1.Index]) { block2 = block2.ImmediateDominator; } @@ -59,7 +27,7 @@ namespace Ryujinx.Graphics.Shader.Translation return block1; } - entry.ImmediateDominator = entry; + cfg.Blocks[0].ImmediateDominator = cfg.Blocks[0]; bool modified; @@ -67,9 +35,9 @@ namespace Ryujinx.Graphics.Shader.Translation { modified = false; - for (int blkIndex = postOrderBlocks.Count - 2; blkIndex >= 0; blkIndex--) + for (int blkIndex = cfg.PostOrderBlocks.Length - 2; blkIndex >= 0; blkIndex--) { - BasicBlock block = postOrderBlocks[blkIndex]; + BasicBlock block = cfg.PostOrderBlocks[blkIndex]; BasicBlock newIDom = null; diff --git a/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs b/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs index c5ebe9e7..d5d30f12 100644 --- a/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs +++ b/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs @@ -13,16 +13,18 @@ namespace Ryujinx.Graphics.Shader.Translation public ShaderConfig Config { get; } - private List<Operation> _operations; + public bool IsNonMain { get; } - private Dictionary<ulong, Operand> _labels; + private readonly IReadOnlyDictionary<ulong, int> _funcs; + private readonly List<Operation> _operations; + private readonly Dictionary<ulong, Operand> _labels; - public EmitterContext(ShaderConfig config) + public EmitterContext(ShaderConfig config, bool isNonMain, IReadOnlyDictionary<ulong, int> funcs) { Config = config; - + IsNonMain = isNonMain; + _funcs = funcs; _operations = new List<Operation>(); - _labels = new Dictionary<ulong, Operand>(); } @@ -71,6 +73,11 @@ namespace Ryujinx.Graphics.Shader.Translation return label; } + public int GetFunctionId(ulong address) + { + return _funcs[address]; + } + public void PrepareForReturn() { if (Config.Stage == ShaderStage.Fragment) diff --git a/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs b/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs index c8d622b2..40f3370f 100644 --- a/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs +++ b/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs @@ -136,6 +136,16 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.BranchIfTrue, d, a); } + public static Operand Call(this EmitterContext context, int funcId, bool returns, params Operand[] args) + { + Operand[] args2 = new Operand[args.Length + 1]; + + args2[0] = Const(funcId); + args.CopyTo(args2, 1); + + return context.Add(Instruction.Call, returns ? Local() : null, args2); + } + public static Operand ConditionalSelect(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ConditionalSelect, Local(), a, b, c); @@ -521,11 +531,16 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.PackHalf2x16, Local(), a, b); } - public static Operand Return(this EmitterContext context) + public static void Return(this EmitterContext context) { context.PrepareForReturn(); + context.Add(Instruction.Return); + } - return context.Add(Instruction.Return); + public static void Return(this EmitterContext context, Operand returnValue) + { + context.PrepareForReturn(); + context.Add(Instruction.Return, null, returnValue); } public static Operand ShiftLeft(this EmitterContext context, Operand a, Operand b) diff --git a/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs b/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs index 286574cf..32c7d2f0 100644 --- a/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs +++ b/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs @@ -287,6 +287,8 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations case Instruction.AtomicOr: case Instruction.AtomicSwap: case Instruction.AtomicXor: + case Instruction.Call: + case Instruction.CallOutArgument: return true; } } diff --git a/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs b/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs new file mode 100644 index 00000000..fd90391f --- /dev/null +++ b/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs @@ -0,0 +1,484 @@ +using Ryujinx.Graphics.Shader.Decoders; +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Numerics; + +namespace Ryujinx.Graphics.Shader.Translation +{ + static class RegisterUsage + { + private const int RegsCount = 256; + private const int RegsMask = RegsCount - 1; + + private const int GprMasks = 4; + private const int PredMasks = 1; + private const int FlagMasks = 1; + private const int TotalMasks = GprMasks + PredMasks + FlagMasks; + + private struct RegisterMask : IEquatable<RegisterMask> + { + public long GprMask0 { get; set; } + public long GprMask1 { get; set; } + public long GprMask2 { get; set; } + public long GprMask3 { get; set; } + public long PredMask { get; set; } + public long FlagMask { get; set; } + + public RegisterMask(long gprMask0, long gprMask1, long gprMask2, long gprMask3, long predMask, long flagMask) + { + GprMask0 = gprMask0; + GprMask1 = gprMask1; + GprMask2 = gprMask2; + GprMask3 = gprMask3; + PredMask = predMask; + FlagMask = flagMask; + } + + public long GetMask(int index) + { + return index switch + { + 0 => GprMask0, + 1 => GprMask1, + 2 => GprMask2, + 3 => GprMask3, + 4 => PredMask, + 5 => FlagMask, + _ => throw new ArgumentOutOfRangeException(nameof(index)) + }; + } + + public static RegisterMask operator &(RegisterMask x, RegisterMask y) + { + return new RegisterMask( + x.GprMask0 & y.GprMask0, + x.GprMask1 & y.GprMask1, + x.GprMask2 & y.GprMask2, + x.GprMask3 & y.GprMask3, + x.PredMask & y.PredMask, + x.FlagMask & y.FlagMask); + } + + public static RegisterMask operator |(RegisterMask x, RegisterMask y) + { + return new RegisterMask( + x.GprMask0 | y.GprMask0, + x.GprMask1 | y.GprMask1, + x.GprMask2 | y.GprMask2, + x.GprMask3 | y.GprMask3, + x.PredMask | y.PredMask, + x.FlagMask | y.FlagMask); + } + + public static RegisterMask operator ~(RegisterMask x) + { + return new RegisterMask( + ~x.GprMask0, + ~x.GprMask1, + ~x.GprMask2, + ~x.GprMask3, + ~x.PredMask, + ~x.FlagMask); + } + + public static bool operator ==(RegisterMask x, RegisterMask y) + { + return x.Equals(y); + } + + public static bool operator !=(RegisterMask x, RegisterMask y) + { + return !x.Equals(y); + } + + public override bool Equals(object obj) + { + return obj is RegisterMask regMask && Equals(regMask); + } + + public bool Equals(RegisterMask other) + { + return GprMask0 == other.GprMask0 && + GprMask1 == other.GprMask1 && + GprMask2 == other.GprMask2 && + GprMask3 == other.GprMask3 && + PredMask == other.PredMask && + FlagMask == other.FlagMask; + } + + public override int GetHashCode() + { + return HashCode.Combine(GprMask0, GprMask1, GprMask2, GprMask3, PredMask, FlagMask); + } + } + + public struct FunctionRegisterUsage + { + public Register[] InArguments { get; } + public Register[] OutArguments { get; } + + public FunctionRegisterUsage(Register[] inArguments, Register[] outArguments) + { + InArguments = inArguments; + OutArguments = outArguments; + } + } + + public static FunctionRegisterUsage RunPass(ControlFlowGraph cfg) + { + List<Register> inArguments = new List<Register>(); + List<Register> outArguments = new List<Register>(); + + // Compute local register inputs and outputs used inside blocks. + RegisterMask[] localInputs = new RegisterMask[cfg.Blocks.Length]; + RegisterMask[] localOutputs = new RegisterMask[cfg.Blocks.Length]; + + foreach (BasicBlock block in cfg.Blocks) + { + for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next) + { + Operation operation = node.Value as Operation; + + for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++) + { + Operand source = operation.GetSource(srcIndex); + + if (source.Type != OperandType.Register) + { + continue; + } + + Register register = source.GetRegister(); + + localInputs[block.Index] |= GetMask(register) & ~localOutputs[block.Index]; + } + + if (operation.Dest != null && operation.Dest.Type == OperandType.Register) + { + localOutputs[block.Index] |= GetMask(operation.Dest.GetRegister()); + } + } + } + + // Compute global register inputs and outputs used across blocks. + RegisterMask[] globalCmnOutputs = new RegisterMask[cfg.Blocks.Length]; + + RegisterMask[] globalInputs = new RegisterMask[cfg.Blocks.Length]; + RegisterMask[] globalOutputs = new RegisterMask[cfg.Blocks.Length]; + + RegisterMask allOutputs = new RegisterMask(); + RegisterMask allCmnOutputs = new RegisterMask(-1L, -1L, -1L, -1L, -1L, -1L); + + bool modified; + + bool firstPass = true; + + do + { + modified = false; + + // Compute register outputs. + for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--) + { + BasicBlock block = cfg.PostOrderBlocks[index]; + + if (block.Predecessors.Count != 0) + { + BasicBlock predecessor = block.Predecessors[0]; + + RegisterMask cmnOutputs = localOutputs[predecessor.Index] | globalCmnOutputs[predecessor.Index]; + + RegisterMask outputs = globalOutputs[predecessor.Index]; + + for (int pIndex = 1; pIndex < block.Predecessors.Count; pIndex++) + { + predecessor = block.Predecessors[pIndex]; + + cmnOutputs &= localOutputs[predecessor.Index] | globalCmnOutputs[predecessor.Index]; + + outputs |= globalOutputs[predecessor.Index]; + } + + globalInputs[block.Index] |= outputs & ~cmnOutputs; + + if (!firstPass) + { + cmnOutputs &= globalCmnOutputs[block.Index]; + } + + if (EndsWithReturn(block)) + { + allCmnOutputs &= cmnOutputs | localOutputs[block.Index]; + } + + if (Exchange(globalCmnOutputs, block.Index, cmnOutputs)) + { + modified = true; + } + + outputs |= localOutputs[block.Index]; + + if (Exchange(globalOutputs, block.Index, globalOutputs[block.Index] | outputs)) + { + allOutputs |= outputs; + modified = true; + } + } + else if (Exchange(globalOutputs, block.Index, localOutputs[block.Index])) + { + allOutputs |= localOutputs[block.Index]; + modified = true; + } + } + + // Compute register inputs. + for (int index = 0; index < cfg.PostOrderBlocks.Length; index++) + { + BasicBlock block = cfg.PostOrderBlocks[index]; + + RegisterMask inputs = localInputs[block.Index]; + + if (block.Next != null) + { + inputs |= globalInputs[block.Next.Index]; + } + + if (block.Branch != null) + { + inputs |= globalInputs[block.Branch.Index]; + } + + inputs &= ~globalCmnOutputs[block.Index]; + + if (Exchange(globalInputs, block.Index, globalInputs[block.Index] | inputs)) + { + modified = true; + } + } + + firstPass = false; + } + while (modified); + + // Insert load and store context instructions where needed. + foreach (BasicBlock block in cfg.Blocks) + { + // The only block without any predecessor should be the entry block. + // It always needs a context load as it is the first block to run. + if (block.Predecessors.Count == 0) + { + RegisterMask inputs = globalInputs[block.Index] | (allOutputs & ~allCmnOutputs); + + LoadLocals(block, inputs, inArguments); + } + + if (EndsWithReturn(block)) + { + StoreLocals(block, allOutputs, inArguments.Count, outArguments); + } + } + + return new FunctionRegisterUsage(inArguments.ToArray(), outArguments.ToArray()); + } + + public static void FixupCalls(BasicBlock[] blocks, FunctionRegisterUsage[] frus) + { + foreach (BasicBlock block in blocks) + { + for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next) + { + Operation operation = node.Value as Operation; + + if (operation.Inst == Instruction.Call) + { + Operand funcId = operation.GetSource(0); + + Debug.Assert(funcId.Type == OperandType.Constant); + + var fru = frus[funcId.Value]; + + Operand[] regs = new Operand[fru.InArguments.Length]; + + for (int i = 0; i < fru.InArguments.Length; i++) + { + regs[i] = OperandHelper.Register(fru.InArguments[i]); + } + + operation.AppendOperands(regs); + + for (int i = 0; i < fru.OutArguments.Length; i++) + { + Operation callOutArgOp = new Operation(Instruction.CallOutArgument, OperandHelper.Register(fru.OutArguments[i])); + + node = block.Operations.AddAfter(node, callOutArgOp); + } + } + } + } + } + + private static bool StartsWith(BasicBlock block, Instruction inst) + { + if (block.Operations.Count == 0) + { + return false; + } + + return block.Operations.First.Value is Operation operation && operation.Inst == inst; + } + + private static bool EndsWith(BasicBlock block, Instruction inst) + { + if (block.Operations.Count == 0) + { + return false; + } + + return block.Operations.Last.Value is Operation operation && operation.Inst == inst; + } + + private static RegisterMask GetMask(Register register) + { + Span<long> gprMasks = stackalloc long[4]; + long predMask = 0; + long flagMask = 0; + + switch (register.Type) + { + case RegisterType.Gpr: + gprMasks[register.Index >> 6] = 1L << (register.Index & 0x3f); + break; + case RegisterType.Predicate: + predMask = 1L << register.Index; + break; + case RegisterType.Flag: + flagMask = 1L << register.Index; + break; + } + + return new RegisterMask(gprMasks[0], gprMasks[1], gprMasks[2], gprMasks[3], predMask, flagMask); + } + + private static bool Exchange(RegisterMask[] masks, int blkIndex, RegisterMask value) + { + RegisterMask oldValue = masks[blkIndex]; + + masks[blkIndex] = value; + + return oldValue != value; + } + + private static void LoadLocals(BasicBlock block, RegisterMask masks, List<Register> inArguments) + { + bool fillArgsList = inArguments.Count == 0; + LinkedListNode<INode> node = null; + int argIndex = 0; + + for (int i = 0; i < TotalMasks; i++) + { + (RegisterType regType, int baseRegIndex) = GetRegTypeAndBaseIndex(i); + long mask = masks.GetMask(i); + + while (mask != 0) + { + int bit = BitOperations.TrailingZeroCount(mask); + + mask &= ~(1L << bit); + + Register register = new Register(baseRegIndex + bit, regType); + + if (fillArgsList) + { + inArguments.Add(register); + } + + Operation copyOp = new Operation(Instruction.Copy, OperandHelper.Register(register), OperandHelper.Argument(argIndex++)); + + if (node == null) + { + node = block.Operations.AddFirst(copyOp); + } + else + { + node = block.Operations.AddAfter(node, copyOp); + } + } + } + + Debug.Assert(argIndex <= inArguments.Count); + } + + private static void StoreLocals(BasicBlock block, RegisterMask masks, int inArgumentsCount, List<Register> outArguments) + { + LinkedListNode<INode> node = null; + int argIndex = inArgumentsCount; + bool fillArgsList = outArguments.Count == 0; + + for (int i = 0; i < TotalMasks; i++) + { + (RegisterType regType, int baseRegIndex) = GetRegTypeAndBaseIndex(i); + long mask = masks.GetMask(i); + + while (mask != 0) + { + int bit = BitOperations.TrailingZeroCount(mask); + + mask &= ~(1L << bit); + + Register register = new Register(baseRegIndex + bit, regType); + + if (fillArgsList) + { + outArguments.Add(register); + } + + Operation copyOp = new Operation(Instruction.Copy, OperandHelper.Argument(argIndex++), OperandHelper.Register(register)); + + if (node == null) + { + node = block.Operations.AddBefore(block.Operations.Last, copyOp); + } + else + { + node = block.Operations.AddAfter(node, copyOp); + } + } + } + + Debug.Assert(argIndex <= inArgumentsCount + outArguments.Count); + } + + private static (RegisterType RegType, int BaseRegIndex) GetRegTypeAndBaseIndex(int i) + { + RegisterType regType = RegisterType.Gpr; + int baseRegIndex = 0; + + if (i < GprMasks) + { + baseRegIndex = i * sizeof(long) * 8; + } + else if (i == GprMasks) + { + regType = RegisterType.Predicate; + } + else + { + regType = RegisterType.Flag; + } + + return (regType, baseRegIndex); + } + + private static bool EndsWithReturn(BasicBlock block) + { + if (!(block.GetLastOp() is Operation operation)) + { + return false; + } + + return operation.Inst == Instruction.Return; + } + } +}
\ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/Translation/Translator.cs b/Ryujinx.Graphics.Shader/Translation/Translator.cs index db0924b3..f8093c84 100644 --- a/Ryujinx.Graphics.Shader/Translation/Translator.cs +++ b/Ryujinx.Graphics.Shader/Translation/Translator.cs @@ -14,6 +14,16 @@ namespace Ryujinx.Graphics.Shader.Translation { private const int HeaderSize = 0x50; + private struct FunctionCode + { + public Operation[] Code { get; } + + public FunctionCode(Operation[] code) + { + Code = code; + } + } + public static ShaderProgram Translate(ulong address, IGpuAccessor gpuAccessor, TranslationFlags flags) { return Translate(DecodeShader(address, gpuAccessor, flags, out ShaderConfig config), config); @@ -21,32 +31,65 @@ namespace Ryujinx.Graphics.Shader.Translation public static ShaderProgram Translate(ulong addressA, ulong addressB, IGpuAccessor gpuAccessor, TranslationFlags flags) { - Operation[] opsA = DecodeShader(addressA, gpuAccessor, flags | TranslationFlags.VertexA, out ShaderConfig configA); - Operation[] opsB = DecodeShader(addressB, gpuAccessor, flags, out ShaderConfig config); + FunctionCode[] funcA = DecodeShader(addressA, gpuAccessor, flags | TranslationFlags.VertexA, out ShaderConfig configA); + FunctionCode[] funcB = DecodeShader(addressB, gpuAccessor, flags, out ShaderConfig config); config.SetUsedFeature(configA.UsedFeatures); - return Translate(Combine(opsA, opsB), config, configA.Size); + return Translate(Combine(funcA, funcB), config, configA.Size); } - private static ShaderProgram Translate(Operation[] ops, ShaderConfig config, int sizeA = 0) + private static ShaderProgram Translate(FunctionCode[] functions, ShaderConfig config, int sizeA = 0) { - BasicBlock[] blocks = ControlFlowGraph.MakeCfg(ops); + var cfgs = new ControlFlowGraph[functions.Length]; + var frus = new RegisterUsage.FunctionRegisterUsage[functions.Length]; - if (blocks.Length > 0) + for (int i = 0; i < functions.Length; i++) { - Dominance.FindDominators(blocks[0], blocks.Length); + cfgs[i] = ControlFlowGraph.Create(functions[i].Code); + + if (i != 0) + { + frus[i] = RegisterUsage.RunPass(cfgs[i]); + } + } + + Function[] funcs = new Function[functions.Length]; + + for (int i = 0; i < functions.Length; i++) + { + var cfg = cfgs[i]; + + int inArgumentsCount = 0; + int outArgumentsCount = 0; + + if (i != 0) + { + var fru = frus[i]; + + inArgumentsCount = fru.InArguments.Length; + outArgumentsCount = fru.OutArguments.Length; + } + + if (cfg.Blocks.Length != 0) + { + RegisterUsage.FixupCalls(cfg.Blocks, frus); + + Dominance.FindDominators(cfg); - Dominance.FindDominanceFrontiers(blocks); + Dominance.FindDominanceFrontiers(cfg.Blocks); - Ssa.Rename(blocks); + Ssa.Rename(cfg.Blocks); - Optimizer.RunPass(blocks, config); + Optimizer.RunPass(cfg.Blocks, config); - Lowering.RunPass(blocks, config); + Lowering.RunPass(cfg.Blocks, config); + } + + funcs[i] = new Function(cfg.Blocks, $"fun{i}", false, inArgumentsCount, outArgumentsCount); } - StructuredProgramInfo sInfo = StructuredProgram.MakeStructuredProgram(blocks, config); + StructuredProgramInfo sInfo = StructuredProgram.MakeStructuredProgram(funcs, config); GlslProgram program = GlslGenerator.Generate(sInfo, config); @@ -62,9 +105,9 @@ namespace Ryujinx.Graphics.Shader.Translation return new ShaderProgram(spInfo, config.Stage, glslCode, config.Size, sizeA); } - private static Operation[] DecodeShader(ulong address, IGpuAccessor gpuAccessor, TranslationFlags flags, out ShaderConfig config) + private static FunctionCode[] DecodeShader(ulong address, IGpuAccessor gpuAccessor, TranslationFlags flags, out ShaderConfig config) { - Block[] cfg; + Block[][] cfg; if ((flags & TranslationFlags.Compute) != 0) { @@ -83,112 +126,131 @@ namespace Ryujinx.Graphics.Shader.Translation { gpuAccessor.Log("Invalid branch detected, failed to build CFG."); - return Array.Empty<Operation>(); + return Array.Empty<FunctionCode>(); + } + + Dictionary<ulong, int> funcIds = new Dictionary<ulong, int>(); + + for (int funcIndex = 0; funcIndex < cfg.Length; funcIndex++) + { + funcIds.Add(cfg[funcIndex][0].Address, funcIndex); } - EmitterContext context = new EmitterContext(config); + List<FunctionCode> funcs = new List<FunctionCode>(); ulong maxEndAddress = 0; - for (int blkIndex = 0; blkIndex < cfg.Length; blkIndex++) + for (int funcIndex = 0; funcIndex < cfg.Length; funcIndex++) { - Block block = cfg[blkIndex]; + EmitterContext context = new EmitterContext(config, funcIndex != 0, funcIds); - if (maxEndAddress < block.EndAddress) + for (int blkIndex = 0; blkIndex < cfg[funcIndex].Length; blkIndex++) { - maxEndAddress = block.EndAddress; - } + Block block = cfg[funcIndex][blkIndex]; - context.CurrBlock = block; + if (maxEndAddress < block.EndAddress) + { + maxEndAddress = block.EndAddress; + } - context.MarkLabel(context.GetLabel(block.Address)); + context.CurrBlock = block; - for (int opIndex = 0; opIndex < block.OpCodes.Count; opIndex++) - { - OpCode op = block.OpCodes[opIndex]; + context.MarkLabel(context.GetLabel(block.Address)); - if ((flags & TranslationFlags.DebugMode) != 0) - { - string instName; + EmitOps(context, block); + } - if (op.Emitter != null) - { - instName = op.Emitter.Method.Name; - } - else - { - instName = "???"; + funcs.Add(new FunctionCode(context.GetOperations())); + } - gpuAccessor.Log($"Invalid instruction at 0x{op.Address:X6} (0x{op.RawOpCode:X16})."); - } + config.SizeAdd((int)maxEndAddress + (flags.HasFlag(TranslationFlags.Compute) ? 0 : HeaderSize)); - string dbgComment = $"0x{op.Address:X6}: 0x{op.RawOpCode:X16} {instName}"; + return funcs.ToArray(); + } - context.Add(new CommentNode(dbgComment)); - } + internal static void EmitOps(EmitterContext context, Block block) + { + for (int opIndex = 0; opIndex < block.OpCodes.Count; opIndex++) + { + OpCode op = block.OpCodes[opIndex]; + + if ((context.Config.Flags & TranslationFlags.DebugMode) != 0) + { + string instName; - if (op.NeverExecute) + if (op.Emitter != null) { - continue; + instName = op.Emitter.Method.Name; } + else + { + instName = "???"; - Operand predSkipLbl = null; + context.Config.GpuAccessor.Log($"Invalid instruction at 0x{op.Address:X6} (0x{op.RawOpCode:X16})."); + } - bool skipPredicateCheck = op is OpCodeBranch opBranch && !opBranch.PushTarget; + string dbgComment = $"0x{op.Address:X6}: 0x{op.RawOpCode:X16} {instName}"; - if (op is OpCodeBranchPop opBranchPop) - { - // If the instruction is a SYNC or BRK instruction with only one - // possible target address, then the instruction is basically - // just a simple branch, we can generate code similar to branch - // instructions, with the condition check on the branch itself. - skipPredicateCheck = opBranchPop.Targets.Count < 2; - } + context.Add(new CommentNode(dbgComment)); + } - if (!(op.Predicate.IsPT || skipPredicateCheck)) - { - Operand label; + if (op.NeverExecute) + { + continue; + } - if (opIndex == block.OpCodes.Count - 1 && block.Next != null) - { - label = context.GetLabel(block.Next.Address); - } - else - { - label = Label(); + Operand predSkipLbl = null; - predSkipLbl = label; - } + bool skipPredicateCheck = op is OpCodeBranch opBranch && !opBranch.PushTarget; - Operand pred = Register(op.Predicate); + if (op is OpCodeBranchPop opBranchPop) + { + // If the instruction is a SYNC or BRK instruction with only one + // possible target address, then the instruction is basically + // just a simple branch, we can generate code similar to branch + // instructions, with the condition check on the branch itself. + skipPredicateCheck = opBranchPop.Targets.Count < 2; + } - if (op.InvertPredicate) - { - context.BranchIfTrue(label, pred); - } - else - { - context.BranchIfFalse(label, pred); - } + if (!(op.Predicate.IsPT || skipPredicateCheck)) + { + Operand label; + + if (opIndex == block.OpCodes.Count - 1 && block.Next != null) + { + label = context.GetLabel(block.Next.Address); } + else + { + label = Label(); - context.CurrOp = op; + predSkipLbl = label; + } - op.Emitter?.Invoke(context); + Operand pred = Register(op.Predicate); - if (predSkipLbl != null) + if (op.InvertPredicate) { - context.MarkLabel(predSkipLbl); + context.BranchIfTrue(label, pred); + } + else + { + context.BranchIfFalse(label, pred); } } - } - config.SizeAdd((int)maxEndAddress + (flags.HasFlag(TranslationFlags.Compute) ? 0 : HeaderSize)); + context.CurrOp = op; - return context.GetOperations(); + op.Emitter?.Invoke(context); + + if (predSkipLbl != null) + { + context.MarkLabel(predSkipLbl); + } + } } - private static Operation[] Combine(Operation[] a, Operation[] b) + private static FunctionCode[] Combine(FunctionCode[] a, FunctionCode[] b) { // Here we combine two shaders. // For shader A: @@ -199,15 +261,17 @@ namespace Ryujinx.Graphics.Shader.Translation // For shader B: // - All user attribute loads on shader B are turned into copies from a // temporary variable, as long that attribute is written by shader A. - List<Operation> output = new List<Operation>(a.Length + b.Length); + FunctionCode[] output = new FunctionCode[a.Length + b.Length - 1]; + + List<Operation> ops = new List<Operation>(a.Length + b.Length); Operand[] temps = new Operand[AttributeConsts.UserAttributesCount * 4]; Operand lblB = Label(); - for (int index = 0; index < a.Length; index++) + for (int index = 0; index < a[0].Code.Length; index++) { - Operation operation = a[index]; + Operation operation = a[0].Code[index]; if (IsUserAttribute(operation.Dest)) { @@ -227,19 +291,19 @@ namespace Ryujinx.Graphics.Shader.Translation if (operation.Inst == Instruction.Return) { - output.Add(new Operation(Instruction.Branch, lblB)); + ops.Add(new Operation(Instruction.Branch, lblB)); } else { - output.Add(operation); + ops.Add(operation); } } - output.Add(new Operation(Instruction.MarkLabel, lblB)); + ops.Add(new Operation(Instruction.MarkLabel, lblB)); - for (int index = 0; index < b.Length; index++) + for (int index = 0; index < b[0].Code.Length; index++) { - Operation operation = b[index]; + Operation operation = b[0].Code[index]; for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++) { @@ -256,10 +320,22 @@ namespace Ryujinx.Graphics.Shader.Translation } } - output.Add(operation); + ops.Add(operation); + } + + output[0] = new FunctionCode(ops.ToArray()); + + for (int i = 1; i < a.Length; i++) + { + output[i] = a[i]; + } + + for (int i = 1; i < b.Length; i++) + { + output[a.Length + i - 1] = b[i]; } - return output.ToArray(); + return output; } private static bool IsUserAttribute(Operand operand) |
