aboutsummaryrefslogtreecommitdiff
path: root/Ryujinx.HLE/HOS/Services/Ssl/SslService
diff options
context:
space:
mode:
Diffstat (limited to 'Ryujinx.HLE/HOS/Services/Ssl/SslService')
-rw-r--r--Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs454
-rw-r--r--Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs25
-rw-r--r--Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs11
-rw-r--r--Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs247
4 files changed, 707 insertions, 30 deletions
diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs
index 24f3d066..fba22f45 100644
--- a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs
+++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs
@@ -1,41 +1,101 @@
using Ryujinx.Common.Logging;
+using Ryujinx.HLE.Exceptions;
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
+using Ryujinx.Memory;
+using System;
using System.Text;
namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
- class ISslConnection : IpcService
+ class ISslConnection : IpcService, IDisposable
{
- public ISslConnection() { }
+ private bool _doNotClockSocket;
+ private bool _getServerCertChain;
+ private bool _skipDefaultVerify;
+ private bool _enableAlpn;
+
+ private SslVersion _sslVersion;
+ private IoMode _ioMode;
+ private VerifyOption _verifyOption;
+ private SessionCacheMode _sessionCacheMode;
+ private string _hostName;
+
+ private ISslConnectionBase _connection;
+ private BsdContext _bsdContext;
+ private readonly long _processId;
+
+ private byte[] _nextAplnProto;
+
+ public ISslConnection(long processId, SslVersion sslVersion)
+ {
+ _processId = processId;
+ _sslVersion = sslVersion;
+ _ioMode = IoMode.Blocking;
+ _sessionCacheMode = SessionCacheMode.None;
+ _verifyOption = VerifyOption.PeerCa | VerifyOption.HostName;
+ }
[CommandHipc(0)]
// SetSocketDescriptor(u32) -> u32
public ResultCode SetSocketDescriptor(ServiceCtx context)
{
- uint socketFd = context.RequestData.ReadUInt32();
- uint duplicateSocketFd = 0;
+ if (_connection != null)
+ {
+ return ResultCode.AlreadyInUse;
+ }
+
+ _bsdContext = BsdContext.GetContext(_processId);
+
+ if (_bsdContext == null)
+ {
+ return ResultCode.InvalidSocket;
+ }
+
+ int inputFd = context.RequestData.ReadInt32();
- context.ResponseData.Write(duplicateSocketFd);
+ int internalFd = _bsdContext.DuplicateFileDescriptor(inputFd);
- Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { socketFd });
+ if (internalFd == -1)
+ {
+ return ResultCode.InvalidSocket;
+ }
+
+ InitializeConnection(internalFd);
+
+ int outputFd = inputFd;
+
+ if (_doNotClockSocket)
+ {
+ outputFd = -1;
+ }
+
+ context.ResponseData.Write(outputFd);
return ResultCode.Success;
}
+ private void InitializeConnection(int socketFd)
+ {
+ ISocket bsdSocket = _bsdContext.RetrieveSocket(socketFd);
+
+ _connection = new SslManagedSocketConnection(_bsdContext, _sslVersion, socketFd, bsdSocket);
+ }
+
[CommandHipc(1)]
// SetHostName(buffer<bytes, 5>)
public ResultCode SetHostName(ServiceCtx context)
{
ulong hostNameDataPosition = context.Request.SendBuff[0].Position;
- ulong hostNameDataSize = context.Request.SendBuff[0].Size;
+ ulong hostNameDataSize = context.Request.SendBuff[0].Size;
byte[] hostNameData = new byte[hostNameDataSize];
context.Memory.Read(hostNameDataPosition, hostNameData);
- string hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0');
+ _hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0');
- Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { hostName });
+ Logger.Info?.Print(LogClass.ServiceSsl, _hostName);
return ResultCode.Success;
}
@@ -44,9 +104,9 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// SetVerifyOption(nn::ssl::sf::VerifyOption)
public ResultCode SetVerifyOption(ServiceCtx context)
{
- VerifyOption verifyOption = (VerifyOption)context.RequestData.ReadUInt32();
+ _verifyOption = (VerifyOption)context.RequestData.ReadUInt32();
- Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { verifyOption });
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption });
return ResultCode.Success;
}
@@ -55,9 +115,67 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// SetIoMode(nn::ssl::sf::IoMode)
public ResultCode SetIoMode(ServiceCtx context)
{
- IoMode ioMode = (IoMode)context.RequestData.ReadUInt32();
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
+
+ _ioMode = (IoMode)context.RequestData.ReadUInt32();
+
+ _connection.Socket.Blocking = _ioMode == IoMode.Blocking;
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode });
+
+ return ResultCode.Success;
+ }
+
+ [CommandHipc(4)]
+ // GetSocketDescriptor() -> u32
+ public ResultCode GetSocketDescriptor(ServiceCtx context)
+ {
+ context.ResponseData.Write(_connection.SocketFd);
+
+ return ResultCode.Success;
+ }
+
+ [CommandHipc(5)]
+ // GetHostName(buffer<bytes, 6>) -> u32
+ public ResultCode GetHostName(ServiceCtx context)
+ {
+ ulong hostNameDataPosition = context.Request.ReceiveBuff[0].Position;
+ ulong hostNameDataSize = context.Request.ReceiveBuff[0].Size;
+
+ byte[] hostNameData = new byte[hostNameDataSize];
- Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { ioMode });
+ Encoding.ASCII.GetBytes(_hostName, hostNameData);
+
+ context.Memory.Write(hostNameDataPosition, hostNameData);
+
+ context.ResponseData.Write((uint)_hostName.Length);
+
+ Logger.Info?.Print(LogClass.ServiceSsl, _hostName);
+
+ return ResultCode.Success;
+ }
+
+ [CommandHipc(6)]
+ // GetVerifyOption() -> nn::ssl::sf::VerifyOption
+ public ResultCode GetVerifyOption(ServiceCtx context)
+ {
+ context.ResponseData.Write((uint)_verifyOption);
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption });
+
+ return ResultCode.Success;
+ }
+
+ [CommandHipc(7)]
+ // GetIoMode() -> nn::ssl::sf::IoMode
+ public ResultCode GetIoMode(ServiceCtx context)
+ {
+ context.ResponseData.Write((uint)_ioMode);
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode });
return ResultCode.Success;
}
@@ -66,32 +184,155 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
// DoHandshake()
public ResultCode DoHandshake(ServiceCtx context)
{
- Logger.Stub?.PrintStub(LogClass.ServiceSsl);
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
- return ResultCode.Success;
+ return _connection.Handshake(_hostName);
+ }
+
+ [CommandHipc(9)]
+ // DoHandshakeGetServerCert() -> (u32, u32, buffer<bytes, 6>)
+ public ResultCode DoHandshakeGetServerCert(ServiceCtx context)
+ {
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
+
+ ResultCode result = _connection.Handshake(_hostName);
+
+ if (result == ResultCode.Success)
+ {
+ if (_getServerCertChain)
+ {
+ using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
+ {
+ result = _connection.GetServerCertificate(_hostName, region.Memory.Span, out uint bufferSize, out uint certificateCount);
+
+ context.ResponseData.Write(bufferSize);
+ context.ResponseData.Write(certificateCount);
+ }
+ }
+ else
+ {
+ context.ResponseData.Write(0);
+ context.ResponseData.Write(0);
+ }
+ }
+
+ return result;
+ }
+
+ [CommandHipc(10)]
+ // Read() -> (u32, buffer<bytes, 6>)
+ public ResultCode Read(ServiceCtx context)
+ {
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
+
+ ResultCode result;
+
+ using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
+ {
+ // TODO: Better error management.
+ result = _connection.Read(out int readCount, region.Memory);
+
+ if (result == ResultCode.Success)
+ {
+ context.ResponseData.Write(readCount);
+ }
+ }
+
+ return result;
}
[CommandHipc(11)]
- // Write(buffer<bytes, 5>) -> u32
+ // Write(buffer<bytes, 5>) -> s32
public ResultCode Write(ServiceCtx context)
{
- ulong inputDataPosition = context.Request.SendBuff[0].Position;
- ulong inputDataSize = context.Request.SendBuff[0].Size;
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
- byte[] data = new byte[inputDataSize];
+ // We don't dispose as this isn't supposed to be modified
+ WritableRegion region = context.Memory.GetWritableRegion(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size);
- context.Memory.Read(inputDataPosition, data);
+ // TODO: Better error management.
+ ResultCode result = _connection.Write(out int writtenCount, region.Memory);
- // NOTE: Tell the guest everything is transferred.
- uint transferredSize = (uint)inputDataSize;
+ if (result == ResultCode.Success)
+ {
+ context.ResponseData.Write(writtenCount);
+ }
- context.ResponseData.Write(transferredSize);
+ return result;
+ }
+
+ [CommandHipc(12)]
+ // Pending() -> s32
+ public ResultCode Pending(ServiceCtx context)
+ {
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
- Logger.Stub?.PrintStub(LogClass.ServiceSsl);
+ context.ResponseData.Write(_connection.Pending());
return ResultCode.Success;
}
+ [CommandHipc(13)]
+ // Peek() -> (s32, buffer<bytes, 6>)
+ public ResultCode Peek(ServiceCtx context)
+ {
+ if (_connection == null)
+ {
+ return ResultCode.NoSocket;
+ }
+
+ ResultCode result;
+
+ using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
+ {
+ // TODO: Better error management.
+ result = _connection.Peek(out int peekCount, region.Memory);
+
+ if (result == ResultCode.Success)
+ {
+ context.ResponseData.Write(peekCount);
+ }
+ }
+
+ return result;
+ }
+
+ [CommandHipc(14)]
+ // Poll(nn::ssl::sf::PollEvent poll_event, u32 timeout) -> nn::ssl::sf::PollEvent
+ public ResultCode Poll(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
+ [CommandHipc(15)]
+ // GetVerifyCertError()
+ public ResultCode GetVerifyCertError(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
+ [CommandHipc(16)]
+ // GetNeededServerCertBufferSize() -> u32
+ public ResultCode GetNeededServerCertBufferSize(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
[CommandHipc(17)]
// SetSessionCacheMode(nn::ssl::sf::SessionCacheMode)
public ResultCode SetSessionCacheMode(ServiceCtx context)
@@ -100,19 +341,176 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sessionCacheMode });
+ _sessionCacheMode = sessionCacheMode;
+
return ResultCode.Success;
}
+ [CommandHipc(18)]
+ // GetSessionCacheMode() -> nn::ssl::sf::SessionCacheMode
+ public ResultCode GetSessionCacheMode(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
+ [CommandHipc(19)]
+ // FlushSessionCache()
+ public ResultCode FlushSessionCache(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
+ [CommandHipc(20)]
+ // SetRenegotiationMode(nn::ssl::sf::RenegotiationMode)
+ public ResultCode SetRenegotiationMode(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
+ [CommandHipc(21)]
+ // GetRenegotiationMode() -> nn::ssl::sf::RenegotiationMode
+ public ResultCode GetRenegotiationMode(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
[CommandHipc(22)]
- // SetOption(b8, nn::ssl::sf::OptionType)
+ // SetOption(b8 value, nn::ssl::sf::OptionType option)
public ResultCode SetOption(ServiceCtx context)
{
- bool optionEnabled = context.RequestData.ReadBoolean();
- OptionType optionType = (OptionType)context.RequestData.ReadUInt32();
+ bool value = context.RequestData.ReadUInt32() != 0;
+ OptionType option = (OptionType)context.RequestData.ReadUInt32();
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option, value });
+
+ return SetOption(option, value);
+ }
+
+ [CommandHipc(23)]
+ // GetOption(nn::ssl::sf::OptionType) -> b8
+ public ResultCode GetOption(ServiceCtx context)
+ {
+ OptionType option = (OptionType)context.RequestData.ReadUInt32();
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option });
+
+ ResultCode result = GetOption(option, out bool value);
+
+ if (result == ResultCode.Success)
+ {
+ context.ResponseData.Write(value);
+ }
+
+ return result;
+ }
+
+ [CommandHipc(24)]
+ // GetVerifyCertErrors() -> (u32, u32, buffer<bytes, 6>)
+ public ResultCode GetVerifyCertErrors(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
- Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { optionType, optionEnabled });
+ [CommandHipc(25)] // 4.0.0+
+ // GetCipherInfo(u32) -> buffer<bytes, 6>
+ public ResultCode GetCipherInfo(ServiceCtx context)
+ {
+ throw new ServiceNotImplementedException(this, context);
+ }
+
+ [CommandHipc(26)]
+ // SetNextAlpnProto(buffer<bytes, 5>) -> u32
+ public ResultCode SetNextAlpnProto(ServiceCtx context)
+ {
+ ulong inputDataPosition = context.Request.SendBuff[0].Position;
+ ulong inputDataSize = context.Request.SendBuff[0].Size;
+
+ _nextAplnProto = new byte[inputDataSize];
+
+ context.Memory.Read(inputDataPosition, _nextAplnProto);
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { inputDataSize });
+
+ return ResultCode.Success;
+ }
+
+ [CommandHipc(27)]
+ // GetNextAlpnProto(buffer<bytes, 6>) -> u32
+ public ResultCode GetNextAlpnProto(ServiceCtx context)
+ {
+ ulong outputDataPosition = context.Request.ReceiveBuff[0].Position;
+ ulong outputDataSize = context.Request.ReceiveBuff[0].Size;
+
+ context.Memory.Write(outputDataPosition, _nextAplnProto);
+
+ context.ResponseData.Write(_nextAplnProto.Length);
+
+ Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { outputDataSize });
return ResultCode.Success;
}
+
+ private ResultCode SetOption(OptionType option, bool value)
+ {
+ switch (option)
+ {
+ case OptionType.DoNotCloseSocket:
+ _doNotClockSocket = value;
+ break;
+
+ case OptionType.GetServerCertChain:
+ _getServerCertChain = value;
+ break;
+
+ case OptionType.SkipDefaultVerify:
+ _skipDefaultVerify = value;
+ break;
+
+ case OptionType.EnableAlpn:
+ _enableAlpn = value;
+ break;
+
+ default:
+ Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}");
+ return ResultCode.InvalidOption;
+ }
+
+ return ResultCode.Success;
+ }
+
+ private ResultCode GetOption(OptionType option, out bool value)
+ {
+ switch (option)
+ {
+ case OptionType.DoNotCloseSocket:
+ value = _doNotClockSocket;
+ break;
+
+ case OptionType.GetServerCertChain:
+ value = _getServerCertChain;
+ break;
+
+ case OptionType.SkipDefaultVerify:
+ value = _skipDefaultVerify;
+ break;
+
+ case OptionType.EnableAlpn:
+ value = _enableAlpn;
+ break;
+
+ default:
+ Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}");
+
+ value = false;
+ return ResultCode.InvalidOption;
+ }
+
+ return ResultCode.Success;
+ }
+
+ public void Dispose()
+ {
+ _connection?.Dispose();
+ }
}
} \ No newline at end of file
diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs
new file mode 100644
index 00000000..74e5fcda
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs
@@ -0,0 +1,25 @@
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
+using System;
+using System.Net.Sockets;
+
+namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
+{
+ interface ISslConnectionBase: IDisposable
+ {
+ int SocketFd { get; }
+
+ ISocket Socket { get; }
+
+ ResultCode Handshake(string hostName);
+
+ ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount);
+
+ ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer);
+
+ ResultCode Read(out int readCount, Memory<byte> buffer);
+
+ ResultCode Peek(out int peekCount, Memory<byte> buffer);
+
+ int Pending();
+ }
+}
diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs
index 718af2cb..0b8cb463 100644
--- a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs
+++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs
@@ -1,4 +1,5 @@
using Ryujinx.Common.Logging;
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System.Text;
@@ -8,16 +9,22 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
{
private uint _connectionCount;
+ private readonly long _processId;
+ private readonly SslVersion _sslVersion;
private ulong _serverCertificateId;
private ulong _clientCertificateId;
- public ISslContext(ServiceCtx context) { }
+ public ISslContext(long processId, SslVersion sslVersion)
+ {
+ _processId = processId;
+ _sslVersion = sslVersion;
+ }
[CommandHipc(2)]
// CreateConnection() -> object<nn::ssl::sf::ISslConnection>
public ResultCode CreateConnection(ServiceCtx context)
{
- MakeObject(context, new ISslConnection());
+ MakeObject(context, new ISslConnection(_processId, _sslVersion));
_connectionCount++;
diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
new file mode 100644
index 00000000..36c8b51a
--- /dev/null
+++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs
@@ -0,0 +1,247 @@
+using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
+using Ryujinx.HLE.HOS.Services.Ssl.Types;
+using System;
+using System.IO;
+using System.Net.Security;
+using System.Net.Sockets;
+using System.Security.Authentication;
+
+namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
+{
+ class SslManagedSocketConnection : ISslConnectionBase
+ {
+ public int SocketFd { get; }
+
+ public ISocket Socket { get; }
+
+ private BsdContext _bsdContext;
+ private SslVersion _sslVersion;
+ private SslStream _stream;
+ private bool _isBlockingSocket;
+ private int _previousReadTimeout;
+
+ public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
+ {
+ _bsdContext = bsdContext;
+ _sslVersion = sslVersion;
+
+ SocketFd = socketFd;
+ Socket = socket;
+ }
+
+ private void StartSslOperation()
+ {
+ // Save blocking state
+ _isBlockingSocket = Socket.Blocking;
+
+ // Force blocking for SslStream
+ Socket.Blocking = true;
+ }
+
+ private void EndSslOperation()
+ {
+ // Restore blocking state
+ Socket.Blocking = _isBlockingSocket;
+ }
+
+ private void StartSslReadOperation()
+ {
+ StartSslOperation();
+
+ if (!_isBlockingSocket)
+ {
+ _previousReadTimeout = _stream.ReadTimeout;
+
+ _stream.ReadTimeout = 1;
+ }
+ }
+
+ private void EndSslReadOperation()
+ {
+ if (!_isBlockingSocket)
+ {
+ _stream.ReadTimeout = _previousReadTimeout;
+ }
+
+ EndSslOperation();
+ }
+
+ private static SslProtocols TranslateSslVersion(SslVersion version)
+ {
+ switch (version & SslVersion.VersionMask)
+ {
+ case SslVersion.Auto:
+ return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13;
+ case SslVersion.TlsV10:
+ return SslProtocols.Tls;
+ case SslVersion.TlsV11:
+ return SslProtocols.Tls11;
+ case SslVersion.TlsV12:
+ return SslProtocols.Tls12;
+ case SslVersion.TlsV13:
+ return SslProtocols.Tls13;
+ default:
+ throw new NotImplementedException(version.ToString());
+ }
+ }
+
+ public ResultCode Handshake(string hostName)
+ {
+ StartSslOperation();
+ _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
+ _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
+ EndSslOperation();
+
+ return ResultCode.Success;
+ }
+
+ public ResultCode Peek(out int peekCount, Memory<byte> buffer)
+ {
+ // NOTE: We cannot support that on .NET SSL API.
+ // As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive.
+ peekCount = -1;
+
+ return ResultCode.WouldBlock;
+ }
+
+ public int Pending()
+ {
+ // Unsupported
+ return 0;
+ }
+
+ private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
+ {
+ switch (error)
+ {
+ case WsaError.WSAETIMEDOUT:
+ resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
+ return true;
+ case WsaError.WSAECONNABORTED:
+ resultCode = ResultCode.ConnectionAbort;
+ return true;
+ case WsaError.WSAECONNRESET:
+ resultCode = ResultCode.ConnectionReset;
+ return true;
+ default:
+ resultCode = ResultCode.Success;
+ return false;
+ }
+ }
+
+ public ResultCode Read(out int readCount, Memory<byte> buffer)
+ {
+ if (!Socket.Poll(0, SelectMode.SelectRead))
+ {
+ readCount = -1;
+
+ return ResultCode.WouldBlock;
+ }
+
+ StartSslReadOperation();
+
+ try
+ {
+ readCount = _stream.Read(buffer.Span);
+ }
+ catch (IOException exception)
+ {
+ readCount = -1;
+
+ if (exception.InnerException is SocketException socketException)
+ {
+ WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
+
+ if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
+ {
+ return result;
+ }
+ else
+ {
+ throw socketException;
+ }
+ }
+ else
+ {
+ throw exception;
+ }
+ }
+ finally
+ {
+ EndSslReadOperation();
+ }
+
+ return ResultCode.Success;
+ }
+
+ public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
+ {
+ if (!Socket.Poll(0, SelectMode.SelectWrite))
+ {
+ writtenCount = 0;
+
+ return ResultCode.WouldBlock;
+ }
+
+ StartSslOperation();
+
+ try
+ {
+ _stream.Write(buffer.Span);
+ }
+ catch (IOException exception)
+ {
+ writtenCount = -1;
+
+ if (exception.InnerException is SocketException socketException)
+ {
+ WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
+
+ if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
+ {
+ return result;
+ }
+ else
+ {
+ throw socketException;
+ }
+ }
+ else
+ {
+ throw exception;
+ }
+ }
+ finally
+ {
+ EndSslOperation();
+ }
+
+ // .NET API doesn't provide the size written, assume all written.
+ writtenCount = buffer.Length;
+
+ return ResultCode.Success;
+ }
+
+ public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
+ {
+ byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();
+
+ storageSize = (uint)rawCertData.Length;
+ certificateCount = 1;
+
+ if (rawCertData.Length > certificates.Length)
+ {
+ return ResultCode.CertBufferTooSmall;
+ }
+
+ rawCertData.CopyTo(certificates);
+
+ return ResultCode.Success;
+ }
+
+ public void Dispose()
+ {
+ _bsdContext.CloseFileDescriptor(SocketFd);
+ }
+ }
+}