diff options
| author | gdkchan <gab.dark.100@gmail.com> | 2022-10-29 13:45:30 -0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-29 13:45:30 -0300 |
| commit | 59cdf310bdc16d537ba5ff3813399c54abbce2b7 (patch) | |
| tree | 4cd095f5b7dc3a18f53c9c94c9001a5f3fd5c268 /Ryujinx.Graphics.Shader/CodeGen | |
| parent | 4e34170a84fc1b2096ad4588dec9460a5f8c9870 (diff) | |
SPIR-V: Fix tessellation control shader output types (#3807)
* SPIR-V: Fix tessellation control shader output types
* Shader cache version bump
Diffstat (limited to 'Ryujinx.Graphics.Shader/CodeGen')
| -rw-r--r-- | Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs | 59 | ||||
| -rw-r--r-- | Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs | 15 |
2 files changed, 70 insertions, 4 deletions
diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs index fe5e11f4..04c05325 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs @@ -262,6 +262,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv Instruction ioVariable, elemIndex; + Instruction invocationId = null; + + if (Config.Stage == ShaderStage.TessellationControl && isOutAttr) + { + invocationId = Load(TypeS32(), Inputs[AttributeConsts.InvocationId]); + } + bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd; if (isUserAttr && @@ -273,7 +280,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex()); var vecIndex = Constant(TypeU32(), (attr - AttributeConsts.UserAttributeBase) >> 4); - if (AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr)) + bool isArray = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr); + + if (invocationId != null && isArray) + { + return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index, vecIndex, elemIndex); + } + else if (invocationId != null) + { + return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, vecIndex, elemIndex); + } + else if (isArray) { return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex); } @@ -308,12 +325,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv if ((type & (AggregateType.Array | AggregateType.Vector)) == 0) { - return isIndexed ? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index) : ioVariable; + if (invocationId != null) + { + return isIndexed + ? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index) + : AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId); + } + else + { + return isIndexed ? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index) : ioVariable; + } } elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex()); - if (isIndexed) + if (invocationId != null && isIndexed) + { + return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index, elemIndex); + } + else if (invocationId != null) + { + return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, elemIndex); + } + else if (isIndexed) { return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, elemIndex); } @@ -327,12 +361,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv { var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input; + Instruction invocationId = null; + + if (Config.Stage == ShaderStage.TessellationControl && isOutAttr) + { + invocationId = Load(TypeS32(), Inputs[AttributeConsts.InvocationId]); + } + elemType = AggregateType.FP32; var ioVariable = isOutAttr ? OutputsArray : InputsArray; var vecIndex = ShiftRightLogical(TypeS32(), attrIndex, Constant(TypeS32(), 2)); var elemIndex = BitwiseAnd(TypeS32(), attrIndex, Constant(TypeS32(), 3)); - if (AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr)) + bool isArray = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr); + + if (invocationId != null && isArray) + { + return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index, vecIndex, elemIndex); + } + else if (invocationId != null) + { + return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, vecIndex, elemIndex); + } + else if (isArray) { return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex); } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs index 1a4decf5..9f8dd7df 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs @@ -473,6 +473,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var attrType = context.TypeVector(context.TypeFP32(), (LiteralInteger)4); attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)MaxAttributes)); + if (context.Config.Stage == ShaderStage.TessellationControl) + { + attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), context.Config.ThreadsPerInputPrimitive)); + } + var spvType = context.TypePointer(StorageClass.Output, attrType); var spvVar = context.Variable(spvType, StorageClass.Output); @@ -543,6 +548,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv } } + if (context.Config.Stage == ShaderStage.TessellationControl && isOutAttr && !perPatch) + { + attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), context.Config.ThreadsPerInputPrimitive)); + } + var spvType = context.TypePointer(storageClass, attrType); var spvVar = context.Variable(spvType, storageClass); @@ -634,6 +644,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)arraySize)); } + if (context.Config.Stage == ShaderStage.TessellationControl && isOutAttr) + { + attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), context.Config.ThreadsPerInputPrimitive)); + } + var spvType = context.TypePointer(storageClass, attrType); var spvVar = context.Variable(spvType, storageClass); |
