diff options
Diffstat (limited to 'Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs')
| -rw-r--r-- | Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs | 247 |
1 files changed, 247 insertions, 0 deletions
diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs new file mode 100644 index 00000000..36c8b51a --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs @@ -0,0 +1,247 @@ +using Ryujinx.HLE.HOS.Services.Sockets.Bsd; +using Ryujinx.HLE.HOS.Services.Ssl.Types; +using System; +using System.IO; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; + +namespace Ryujinx.HLE.HOS.Services.Ssl.SslService +{ + class SslManagedSocketConnection : ISslConnectionBase + { + public int SocketFd { get; } + + public ISocket Socket { get; } + + private BsdContext _bsdContext; + private SslVersion _sslVersion; + private SslStream _stream; + private bool _isBlockingSocket; + private int _previousReadTimeout; + + public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket) + { + _bsdContext = bsdContext; + _sslVersion = sslVersion; + + SocketFd = socketFd; + Socket = socket; + } + + private void StartSslOperation() + { + // Save blocking state + _isBlockingSocket = Socket.Blocking; + + // Force blocking for SslStream + Socket.Blocking = true; + } + + private void EndSslOperation() + { + // Restore blocking state + Socket.Blocking = _isBlockingSocket; + } + + private void StartSslReadOperation() + { + StartSslOperation(); + + if (!_isBlockingSocket) + { + _previousReadTimeout = _stream.ReadTimeout; + + _stream.ReadTimeout = 1; + } + } + + private void EndSslReadOperation() + { + if (!_isBlockingSocket) + { + _stream.ReadTimeout = _previousReadTimeout; + } + + EndSslOperation(); + } + + private static SslProtocols TranslateSslVersion(SslVersion version) + { + switch (version & SslVersion.VersionMask) + { + case SslVersion.Auto: + return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13; + case SslVersion.TlsV10: + return SslProtocols.Tls; + case SslVersion.TlsV11: + return SslProtocols.Tls11; + case SslVersion.TlsV12: + return SslProtocols.Tls12; + case SslVersion.TlsV13: + return SslProtocols.Tls13; + default: + throw new NotImplementedException(version.ToString()); + } + } + + public ResultCode Handshake(string hostName) + { + StartSslOperation(); + _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null); + _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false); + EndSslOperation(); + + return ResultCode.Success; + } + + public ResultCode Peek(out int peekCount, Memory<byte> buffer) + { + // NOTE: We cannot support that on .NET SSL API. + // As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive. + peekCount = -1; + + return ResultCode.WouldBlock; + } + + public int Pending() + { + // Unsupported + return 0; + } + + private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode) + { + switch (error) + { + case WsaError.WSAETIMEDOUT: + resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock; + return true; + case WsaError.WSAECONNABORTED: + resultCode = ResultCode.ConnectionAbort; + return true; + case WsaError.WSAECONNRESET: + resultCode = ResultCode.ConnectionReset; + return true; + default: + resultCode = ResultCode.Success; + return false; + } + } + + public ResultCode Read(out int readCount, Memory<byte> buffer) + { + if (!Socket.Poll(0, SelectMode.SelectRead)) + { + readCount = -1; + + return ResultCode.WouldBlock; + } + + StartSslReadOperation(); + + try + { + readCount = _stream.Read(buffer.Span); + } + catch (IOException exception) + { + readCount = -1; + + if (exception.InnerException is SocketException socketException) + { + WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode; + + if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result)) + { + return result; + } + else + { + throw socketException; + } + } + else + { + throw exception; + } + } + finally + { + EndSslReadOperation(); + } + + return ResultCode.Success; + } + + public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer) + { + if (!Socket.Poll(0, SelectMode.SelectWrite)) + { + writtenCount = 0; + + return ResultCode.WouldBlock; + } + + StartSslOperation(); + + try + { + _stream.Write(buffer.Span); + } + catch (IOException exception) + { + writtenCount = -1; + + if (exception.InnerException is SocketException socketException) + { + WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode; + + if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result)) + { + return result; + } + else + { + throw socketException; + } + } + else + { + throw exception; + } + } + finally + { + EndSslOperation(); + } + + // .NET API doesn't provide the size written, assume all written. + writtenCount = buffer.Length; + + return ResultCode.Success; + } + + public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount) + { + byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData(); + + storageSize = (uint)rawCertData.Length; + certificateCount = 1; + + if (rawCertData.Length > certificates.Length) + { + return ResultCode.CertBufferTooSmall; + } + + rawCertData.CopyTo(certificates); + + return ResultCode.Success; + } + + public void Dispose() + { + _bsdContext.CloseFileDescriptor(SocketFd); + } + } +} |
