diff options
| author | gdkchan <gab.dark.100@gmail.com> | 2021-02-10 21:54:42 -0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-02-11 01:54:42 +0100 |
| commit | c465d771dd099b0ffbb0792b3e74148e01259f19 (patch) | |
| tree | bec6a92ed01427f2300c417ed26de0b9d2361920 /Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs | |
| parent | 172ec326e598971f2251e5acdcfa65faa7291396 (diff) | |
Enable multithreaded VP9 decoding (#2009)
* Enable multithreaded VP9 decoding
* Limit the number of threads used for video decoding
Diffstat (limited to 'Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs')
| -rw-r--r-- | Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs | 180 |
1 file changed, 176 insertions, 4 deletions
diff --git a/Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs b/Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs index 9e267376..2963f7cf 100644 --- a/Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs +++ b/Ryujinx.Graphics.Nvdec.Vp9/DecodeFrame.cs @@ -1,13 +1,14 @@ using Ryujinx.Common.Memory; +using Ryujinx.Graphics.Nvdec.Vp9.Common; +using Ryujinx.Graphics.Nvdec.Vp9.Dsp; +using Ryujinx.Graphics.Nvdec.Vp9.Types; +using Ryujinx.Graphics.Video; using System; using System.Buffers.Binary; using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using Ryujinx.Graphics.Nvdec.Vp9.Common; -using Ryujinx.Graphics.Nvdec.Vp9.Dsp; -using Ryujinx.Graphics.Nvdec.Vp9.Types; -using Ryujinx.Graphics.Video; +using System.Threading.Tasks; using Mv = Ryujinx.Graphics.Nvdec.Vp9.Types.Mv; namespace Ryujinx.Graphics.Nvdec.Vp9 @@ -1095,6 +1096,19 @@ namespace Ryujinx.Graphics.Nvdec.Vp9 data = data.Slice(size); } + private static void GetTileBuffers(ref Vp9Common cm, ArrayPtr<byte> data, int tileCols, ref Array64<TileBuffer> tileBuffers) + { + int c; + + for (c = 0; c < tileCols; ++c) + { + bool isLast = c == tileCols - 1; + ref TileBuffer buf = ref tileBuffers[c]; + buf.Col = c; + GetTileBuffer(isLast, ref cm.Error, ref data, ref buf); + } + } + private static void GetTileBuffers( ref Vp9Common cm, ArrayPtr<byte> data, @@ -1181,5 +1195,163 @@ namespace Ryujinx.Graphics.Nvdec.Vp9 // Get last tile data. 
return cm.TileWorkerData[tileCols * tileRows - 1].BitReader.FindEnd(); } + + private static bool DecodeTileCol(ref TileWorkerData tileData, ref Vp9Common cm, ref Array64<TileBuffer> tileBuffers) + { + ref TileInfo tile = ref tileData.Xd.Tile; + int finalCol = (1 << cm.Log2TileCols) - 1; + ArrayPtr<byte> bitReaderEnd = ArrayPtr<byte>.Null; + + int n = tileData.BufStart; + + tileData.Xd.Corrupted = false; + + do + { + ref TileBuffer buf = ref tileBuffers[n]; + + Debug.Assert(cm.Log2TileRows == 0); + tileData.Dqcoeff = new Array32<Array32<int>>(); + tile.Init(ref cm, 0, buf.Col); + SetupTokenDecoder(buf.Data, buf.Size, ref tileData.ErrorInfo, ref tileData.BitReader); + cm.InitMacroBlockD(ref tileData.Xd, new ArrayPtr<int>(ref tileData.Dqcoeff[0][0], 32 * 32)); + tileData.Xd.ErrorInfo = new Ptr<InternalErrorInfo>(ref tileData.ErrorInfo); + + for (int miRow = tile.MiRowStart; miRow < tile.MiRowEnd; miRow += Constants.MiBlockSize) + { + tileData.Xd.LeftContext = new Array3<Array16<sbyte>>(); + tileData.Xd.LeftSegContext = new Array8<sbyte>(); + for (int miCol = tile.MiColStart; miCol < tile.MiColEnd; miCol += Constants.MiBlockSize) + { + DecodePartition(ref tileData, ref cm, miRow, miCol, BlockSize.Block64x64, 4); + } + } + + if (buf.Col == finalCol) + { + bitReaderEnd = tileData.BitReader.FindEnd(); + } + } while (!tileData.Xd.Corrupted && ++n <= tileData.BufEnd); + + tileData.DataEnd = bitReaderEnd; + return !tileData.Xd.Corrupted; + } + + public static unsafe ArrayPtr<byte> DecodeTilesMt(ref Vp9Common cm, ArrayPtr<byte> data, int maxThreads) + { + ArrayPtr<byte> bitReaderEnd = ArrayPtr<byte>.Null; + + int tileCols = 1 << cm.Log2TileCols; + int tileRows = 1 << cm.Log2TileRows; + int totalTiles = tileCols * tileRows; + int numWorkers = Math.Min(maxThreads, tileCols); + int n; + + Debug.Assert(tileCols <= (1 << 6)); + Debug.Assert(tileRows == 1); + + cm.AboveContext.ToSpan().Fill(0); + cm.AboveSegContext.ToSpan().Fill(0); + + for (n = 0; n < numWorkers; ++n) + { + ref 
TileWorkerData tileData = ref cm.TileWorkerData[n + totalTiles]; + + tileData.Xd = cm.Mb; + tileData.Xd.Counts = new Ptr<Vp9BackwardUpdates>(ref tileData.Counts); + tileData.Counts = new Vp9BackwardUpdates(); + } + + Array64<TileBuffer> tileBuffers = new Array64<TileBuffer>(); + + GetTileBuffers(ref cm, data, tileCols, ref tileBuffers); + + tileBuffers.ToSpan().Slice(0, tileCols).Sort(CompareTileBuffers); + + if (numWorkers == tileCols) + { + TileBuffer largest = tileBuffers[0]; + Span<TileBuffer> buffers = tileBuffers.ToSpan(); + buffers.Slice(1).CopyTo(buffers.Slice(0, tileBuffers.Length - 1)); + tileBuffers[tileCols - 1] = largest; + } + else + { + int start = 0, end = tileCols - 2; + TileBuffer tmp; + + // Interleave the tiles to distribute the load between threads, assuming a + // larger tile implies it is more difficult to decode. + while (start < end) + { + tmp = tileBuffers[start]; + tileBuffers[start] = tileBuffers[end]; + tileBuffers[end] = tmp; + start += 2; + end -= 2; + } + } + + int baseVal = tileCols / numWorkers; + int remain = tileCols % numWorkers; + int bufStart = 0; + + for (n = 0; n < numWorkers; ++n) + { + int count = baseVal + (remain + n) / numWorkers; + ref TileWorkerData tileData = ref cm.TileWorkerData[n + totalTiles]; + + tileData.BufStart = bufStart; + tileData.BufEnd = bufStart + count - 1; + tileData.DataEnd = data.Slice(data.Length); + bufStart += count; + } + + Ptr<Vp9Common> cmPtr = new Ptr<Vp9Common>(ref cm); + + Parallel.For(0, numWorkers, (n) => + { + ref TileWorkerData tileData = ref cmPtr.Value.TileWorkerData[n + totalTiles]; + + if (!DecodeTileCol(ref tileData, ref cmPtr.Value, ref tileBuffers)) + { + cmPtr.Value.Mb.Corrupted = true; + } + }); + + for (; n > 0; --n) + { + if (bitReaderEnd.IsNull) + { + ref TileWorkerData tileData = ref cm.TileWorkerData[n - 1 + totalTiles]; + bitReaderEnd = tileData.DataEnd; + } + } + + for (n = 0; n < numWorkers; ++n) + { + ref TileWorkerData tileData = ref cm.TileWorkerData[n + totalTiles]; 
+ AccumulateFrameCounts(ref cm.Counts.Value, ref tileData.Counts); + } + + Debug.Assert(!bitReaderEnd.IsNull || cm.Mb.Corrupted); + return bitReaderEnd; + } + + private static int CompareTileBuffers(TileBuffer bufA, TileBuffer bufB) + { + return (bufA.Size < bufB.Size ? 1 : 0) - (bufA.Size > bufB.Size ? 1 : 0); + } + + private static void AccumulateFrameCounts(ref Vp9BackwardUpdates accum, ref Vp9BackwardUpdates counts) + { + Span<uint> a = MemoryMarshal.Cast<Vp9BackwardUpdates, uint>(MemoryMarshal.CreateSpan(ref accum, 1)); + Span<uint> c = MemoryMarshal.Cast<Vp9BackwardUpdates, uint>(MemoryMarshal.CreateSpan(ref counts, 1)); + + for (int i = 0; i < a.Length; i++) + { + a[i] += c[i]; + } + } } } |
