diff options
Diffstat (limited to 'Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs')
| -rw-r--r-- | Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs | 454 |
1 files changed, 426 insertions, 28 deletions
diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs index 24f3d066..fba22f45 100644 --- a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs +++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs @@ -1,41 +1,101 @@ using Ryujinx.Common.Logging; +using Ryujinx.HLE.Exceptions; +using Ryujinx.HLE.HOS.Services.Sockets.Bsd; using Ryujinx.HLE.HOS.Services.Ssl.Types; +using Ryujinx.Memory; +using System; using System.Text; namespace Ryujinx.HLE.HOS.Services.Ssl.SslService { - class ISslConnection : IpcService + class ISslConnection : IpcService, IDisposable { - public ISslConnection() { } + private bool _doNotClockSocket; + private bool _getServerCertChain; + private bool _skipDefaultVerify; + private bool _enableAlpn; + + private SslVersion _sslVersion; + private IoMode _ioMode; + private VerifyOption _verifyOption; + private SessionCacheMode _sessionCacheMode; + private string _hostName; + + private ISslConnectionBase _connection; + private BsdContext _bsdContext; + private readonly long _processId; + + private byte[] _nextAplnProto; + + public ISslConnection(long processId, SslVersion sslVersion) + { + _processId = processId; + _sslVersion = sslVersion; + _ioMode = IoMode.Blocking; + _sessionCacheMode = SessionCacheMode.None; + _verifyOption = VerifyOption.PeerCa | VerifyOption.HostName; + } [CommandHipc(0)] // SetSocketDescriptor(u32) -> u32 public ResultCode SetSocketDescriptor(ServiceCtx context) { - uint socketFd = context.RequestData.ReadUInt32(); - uint duplicateSocketFd = 0; + if (_connection != null) + { + return ResultCode.AlreadyInUse; + } + + _bsdContext = BsdContext.GetContext(_processId); + + if (_bsdContext == null) + { + return ResultCode.InvalidSocket; + } + + int inputFd = context.RequestData.ReadInt32(); - context.ResponseData.Write(duplicateSocketFd); + int internalFd = _bsdContext.DuplicateFileDescriptor(inputFd); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { socketFd }); + if (internalFd == -1) + { + return ResultCode.InvalidSocket; + } + + InitializeConnection(internalFd); + + int outputFd = inputFd; + + if (_doNotClockSocket) + { + outputFd = -1; + } + + context.ResponseData.Write(outputFd); return ResultCode.Success; } + private void InitializeConnection(int socketFd) + { + ISocket bsdSocket = _bsdContext.RetrieveSocket(socketFd); + + _connection = new SslManagedSocketConnection(_bsdContext, _sslVersion, socketFd, bsdSocket); + } + [CommandHipc(1)] // SetHostName(buffer<bytes, 5>) public ResultCode SetHostName(ServiceCtx context) { ulong hostNameDataPosition = context.Request.SendBuff[0].Position; - ulong hostNameDataSize = context.Request.SendBuff[0].Size; + ulong hostNameDataSize = context.Request.SendBuff[0].Size; byte[] hostNameData = new byte[hostNameDataSize]; context.Memory.Read(hostNameDataPosition, hostNameData); - string hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0'); + _hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0'); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { hostName }); + Logger.Info?.Print(LogClass.ServiceSsl, _hostName); return ResultCode.Success; } @@ -44,9 +104,9 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService // SetVerifyOption(nn::ssl::sf::VerifyOption) public ResultCode SetVerifyOption(ServiceCtx context) { - VerifyOption verifyOption = (VerifyOption)context.RequestData.ReadUInt32(); + _verifyOption = (VerifyOption)context.RequestData.ReadUInt32(); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { verifyOption }); + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption }); return ResultCode.Success; } @@ -55,9 +115,67 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService // SetIoMode(nn::ssl::sf::IoMode) public ResultCode SetIoMode(ServiceCtx context) { - IoMode ioMode = (IoMode)context.RequestData.ReadUInt32(); + if (_connection == null) + { + return ResultCode.NoSocket; + } + + _ioMode = (IoMode)context.RequestData.ReadUInt32(); + + _connection.Socket.Blocking = _ioMode == IoMode.Blocking; + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode }); + + return ResultCode.Success; + } + + [CommandHipc(4)] + // GetSocketDescriptor() -> u32 + public ResultCode GetSocketDescriptor(ServiceCtx context) + { + context.ResponseData.Write(_connection.SocketFd); + + return ResultCode.Success; + } + + [CommandHipc(5)] + // GetHostName(buffer<bytes, 6>) -> u32 + public ResultCode GetHostName(ServiceCtx context) + { + ulong hostNameDataPosition = context.Request.ReceiveBuff[0].Position; + ulong hostNameDataSize = context.Request.ReceiveBuff[0].Size; + + byte[] hostNameData = new byte[hostNameDataSize]; - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { ioMode }); + Encoding.ASCII.GetBytes(_hostName, hostNameData); + + context.Memory.Write(hostNameDataPosition, hostNameData); + + context.ResponseData.Write((uint)_hostName.Length); + + Logger.Info?.Print(LogClass.ServiceSsl, _hostName); + + return ResultCode.Success; + } + + [CommandHipc(6)] + // GetVerifyOption() -> nn::ssl::sf::VerifyOption + public ResultCode GetVerifyOption(ServiceCtx context) + { + context.ResponseData.Write((uint)_verifyOption); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption }); + + return ResultCode.Success; + } + + [CommandHipc(7)] + // GetIoMode() -> nn::ssl::sf::IoMode + public ResultCode GetIoMode(ServiceCtx context) + { + context.ResponseData.Write((uint)_ioMode); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode }); return ResultCode.Success; } @@ -66,32 +184,155 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService // DoHandshake() public ResultCode DoHandshake(ServiceCtx context) { - Logger.Stub?.PrintStub(LogClass.ServiceSsl); + if (_connection == null) + { + return ResultCode.NoSocket; + } - return ResultCode.Success; + return _connection.Handshake(_hostName); + } + + [CommandHipc(9)] + // DoHandshakeGetServerCert() -> (u32, u32, buffer<bytes, 6>) + public ResultCode DoHandshakeGetServerCert(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + ResultCode result = _connection.Handshake(_hostName); + + if (result == ResultCode.Success) + { + if (_getServerCertChain) + { + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + result = _connection.GetServerCertificate(_hostName, region.Memory.Span, out uint bufferSize, out uint certificateCount); + + context.ResponseData.Write(bufferSize); + context.ResponseData.Write(certificateCount); + } + } + else + { + context.ResponseData.Write(0); + context.ResponseData.Write(0); + } + } + + return result; + } + + [CommandHipc(10)] + // Read() -> (u32, buffer<bytes, 6>) + public ResultCode Read(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + ResultCode result; + + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + // TODO: Better error management. + result = _connection.Read(out int readCount, region.Memory); + + if (result == ResultCode.Success) + { + context.ResponseData.Write(readCount); + } + } + + return result; } [CommandHipc(11)] - // Write(buffer<bytes, 5>) -> u32 + // Write(buffer<bytes, 5>) -> s32 public ResultCode Write(ServiceCtx context) { - ulong inputDataPosition = context.Request.SendBuff[0].Position; - ulong inputDataSize = context.Request.SendBuff[0].Size; + if (_connection == null) + { + return ResultCode.NoSocket; + } - byte[] data = new byte[inputDataSize]; + // We don't dispose as this isn't supposed to be modified + WritableRegion region = context.Memory.GetWritableRegion(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size); - context.Memory.Read(inputDataPosition, data); + // TODO: Better error management. + ResultCode result = _connection.Write(out int writtenCount, region.Memory); - // NOTE: Tell the guest everything is transferred. - uint transferredSize = (uint)inputDataSize; + if (result == ResultCode.Success) + { + context.ResponseData.Write(writtenCount); + } - context.ResponseData.Write(transferredSize); + return result; + } + + [CommandHipc(12)] + // Pending() -> s32 + public ResultCode Pending(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } - Logger.Stub?.PrintStub(LogClass.ServiceSsl); + context.ResponseData.Write(_connection.Pending()); return ResultCode.Success; } + [CommandHipc(13)] + // Peek() -> (s32, buffer<bytes, 6>) + public ResultCode Peek(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + ResultCode result; + + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + // TODO: Better error management. + result = _connection.Peek(out int peekCount, region.Memory); + + if (result == ResultCode.Success) + { + context.ResponseData.Write(peekCount); + } + } + + return result; + } + + [CommandHipc(14)] + // Poll(nn::ssl::sf::PollEvent poll_event, u32 timeout) -> nn::ssl::sf::PollEvent + public ResultCode Poll(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(15)] + // GetVerifyCertError() + public ResultCode GetVerifyCertError(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(16)] + // GetNeededServerCertBufferSize() -> u32 + public ResultCode GetNeededServerCertBufferSize(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + [CommandHipc(17)] // SetSessionCacheMode(nn::ssl::sf::SessionCacheMode) public ResultCode SetSessionCacheMode(ServiceCtx context) @@ -100,19 +341,176 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sessionCacheMode }); + _sessionCacheMode = sessionCacheMode; + return ResultCode.Success; } + [CommandHipc(18)] + // GetSessionCacheMode() -> nn::ssl::sf::SessionCacheMode + public ResultCode GetSessionCacheMode(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(19)] + // FlushSessionCache() + public ResultCode FlushSessionCache(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(20)] + // SetRenegotiationMode(nn::ssl::sf::RenegotiationMode) + public ResultCode SetRenegotiationMode(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(21)] + // GetRenegotiationMode() -> nn::ssl::sf::RenegotiationMode + public ResultCode GetRenegotiationMode(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + [CommandHipc(22)] - // SetOption(b8, nn::ssl::sf::OptionType) + // SetOption(b8 value, nn::ssl::sf::OptionType option) public ResultCode SetOption(ServiceCtx context) { - bool optionEnabled = context.RequestData.ReadBoolean(); - OptionType optionType = (OptionType)context.RequestData.ReadUInt32(); + bool value = context.RequestData.ReadUInt32() != 0; + OptionType option = (OptionType)context.RequestData.ReadUInt32(); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option, value }); + + return SetOption(option, value); + } + + [CommandHipc(23)] + // GetOption(nn::ssl::sf::OptionType) -> b8 + public ResultCode GetOption(ServiceCtx context) + { + OptionType option = (OptionType)context.RequestData.ReadUInt32(); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option }); + + ResultCode result = GetOption(option, out bool value); + + if (result == ResultCode.Success) + { + context.ResponseData.Write(value); + } + + return result; + } + + [CommandHipc(24)] + // GetVerifyCertErrors() -> (u32, u32, buffer<bytes, 6>) + public ResultCode GetVerifyCertErrors(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { optionType, optionEnabled }); + [CommandHipc(25)] // 4.0.0+ + // GetCipherInfo(u32) -> buffer<bytes, 6> + public ResultCode GetCipherInfo(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(26)] + // SetNextAlpnProto(buffer<bytes, 5>) -> u32 + public ResultCode SetNextAlpnProto(ServiceCtx context) + { + ulong inputDataPosition = context.Request.SendBuff[0].Position; + ulong inputDataSize = context.Request.SendBuff[0].Size; + + _nextAplnProto = new byte[inputDataSize]; + + context.Memory.Read(inputDataPosition, _nextAplnProto); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { inputDataSize }); + + return ResultCode.Success; + } + + [CommandHipc(27)] + // GetNextAlpnProto(buffer<bytes, 6>) -> u32 + public ResultCode GetNextAlpnProto(ServiceCtx context) + { + ulong outputDataPosition = context.Request.ReceiveBuff[0].Position; + ulong outputDataSize = context.Request.ReceiveBuff[0].Size; + + context.Memory.Write(outputDataPosition, _nextAplnProto); + + context.ResponseData.Write(_nextAplnProto.Length); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { outputDataSize }); return ResultCode.Success; } + + private ResultCode SetOption(OptionType option, bool value) + { + switch (option) + { + case OptionType.DoNotCloseSocket: + _doNotClockSocket = value; + break; + + case OptionType.GetServerCertChain: + _getServerCertChain = value; + break; + + case OptionType.SkipDefaultVerify: + _skipDefaultVerify = value; + break; + + case OptionType.EnableAlpn: + _enableAlpn = value; + break; + + default: + Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}"); + return ResultCode.InvalidOption; + } + + return ResultCode.Success; + } + + private ResultCode GetOption(OptionType option, out bool value) + { + switch (option) + { + case OptionType.DoNotCloseSocket: + value = _doNotClockSocket; + break; + + case OptionType.GetServerCertChain: + value = _getServerCertChain; + break; + + case OptionType.SkipDefaultVerify: + value = _skipDefaultVerify; + break; + + case OptionType.EnableAlpn: + value = _enableAlpn; + break; + + default: + Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}"); + + value = false; + return ResultCode.InvalidOption; + } + + return ResultCode.Success; + } + + public void Dispose() + { + _connection?.Dispose(); + } } }
\ No newline at end of file |
