Compare commits

...

12 Commits

Author SHA1 Message Date
TSRBerry 583f9c6bd9
Merge 1ad8bea798 into a23d8cb92f 2024-05-10 13:23:49 +00:00
Marco Carvalho a23d8cb92f
Replace "List.ForEach" for "foreach" (#6783)
* Replace "List.ForEach" for "foreach"

* dotnet format

* Update Ptc.cs

* Update GpuContext.cs
2024-05-08 13:53:25 +02:00
TSR Berry 1ad8bea798
Apply dotnet format whitespace 2024-04-21 00:42:58 +02:00
TSR Berry 38f9ee73b5
Add a test for ManagedProxySocket.Bind 2024-04-21 00:42:57 +02:00
TSR Berry 057ab9a2c7
Add a mock socks5 server for tests 2024-04-21 00:42:57 +02:00
TSR Berry a2867f473e
Ensure requests and responses are correctly encoded 2024-04-21 00:42:57 +02:00
TSR Berry c42241d834
Add a few simple tests
Ensure the struct sizes are correct
2024-04-21 00:42:57 +02:00
TSR Berry 734db17bb9
Rename SocksUdpIpv4Header to SocksIpv4UdpHeader 2024-04-21 00:42:57 +02:00
TSR Berry 6319912b58
Apply formatting 2024-04-21 00:42:57 +02:00
TSR Berry 79df779b52
Make proxy socket compatible with .NET 8 2024-04-21 00:42:57 +02:00
TSR Berry 5b54e72b2b
sockets: Add proxy socket 2024-04-21 00:42:56 +02:00
TSR Berry 8b832f75ab
sockets: Rename Refcount to RefCount 2024-04-21 00:42:56 +02:00
28 changed files with 1613 additions and 19 deletions

View File

@ -857,8 +857,14 @@ namespace ARMeilleure.Translation.PTC
Stopwatch sw = Stopwatch.StartNew();
threads.ForEach((thread) => thread.Start());
threads.ForEach((thread) => thread.Join());
foreach (var thread in threads)
{
thread.Start();
}
foreach (var thread in threads)
{
thread.Join();
}
threads.Clear();

View File

@ -395,8 +395,14 @@ namespace Ryujinx.Graphics.Gpu
{
Renderer.CreateSync(SyncNumber, strict);
SyncActions.ForEach(action => action.SyncPreAction(syncpoint));
SyncpointActions.ForEach(action => action.SyncPreAction(syncpoint));
foreach (var action in SyncActions)
{
action.SyncPreAction(syncpoint);
}
foreach (var action in SyncpointActions)
{
action.SyncPreAction(syncpoint);
}
SyncNumber++;

View File

@ -103,7 +103,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
{
lock (_lock)
{
oldFile.Refcount++;
oldFile.RefCount++;
return RegisterFileDescriptor(oldFile);
}
@ -118,9 +118,9 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
if (file != null)
{
file.Refcount--;
file.RefCount--;
if (file.Refcount <= 0)
if (file.RefCount <= 0)
{
file.Dispose();
}

View File

@ -1,6 +1,7 @@
using Ryujinx.Common;
using Ryujinx.Common.Logging;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using Ryujinx.Memory;
using System;
@ -95,10 +96,8 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
}
}
ISocket newBsdSocket = new ManagedSocket(netDomain, (SocketType)type, protocol)
{
Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking),
};
ISocket newBsdSocket = ProxyManager.GetSocket(netDomain, (SocketType)type, protocol);
newBsdSocket.Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking);
LinuxError errno = LinuxError.SUCCESS;

View File

@ -6,7 +6,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
interface IFileDescriptor : IDisposable
{
bool Blocking { get; set; }
int Refcount { get; set; }
int RefCount { get; set; }
LinuxError Read(out int readSize, Span<byte> buffer);

View File

@ -32,7 +32,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
UpdateEventStates();
}
public int Refcount { get; set; }
public int RefCount { get; set; }
public void Dispose()
{

View File

@ -0,0 +1,733 @@
using Ryujinx.Common.Logging;
using Ryujinx.Common.Memory;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Authentication;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class ManagedProxySocket : ISocket
{
private static readonly IPEndPoint _endpointZero = new(IPAddress.Any, 0);
private readonly EndPoint _proxyEndpoint;
private IProxyAuth _proxyAuth;
private bool _ready;
private IPEndPoint _udpEndpoint;
private Socket _udpSocket;
public Socket Socket { get; private set; } = new(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
public int RefCount { get; set; }
public AddressFamily AddressFamily { get; }
public SocketType SocketType { get; }
public ProtocolType ProtocolType { get; }
public bool Blocking
{
get
{
return _udpEndpoint != null ? _udpSocket.Blocking : Socket.Blocking;
}
set
{
if (_udpEndpoint != null)
{
_udpSocket.Blocking = value;
}
else
{
Socket.Blocking = value;
}
}
}
public IntPtr Handle => _udpEndpoint != null ? _udpSocket.Handle : Socket.Handle;
public IPEndPoint RemoteEndPoint { get; private set; }
public IPEndPoint LocalEndPoint { get; private set; }
public ManagedProxySocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType, EndPoint proxyEndpoint)
{
AddressFamily = addressFamily;
SocketType = socketType;
ProtocolType = protocolType;
_proxyEndpoint = proxyEndpoint;
RefCount = 1;
}
private ManagedProxySocket(ManagedProxySocket oldSocket)
{
AddressFamily = oldSocket.AddressFamily;
SocketType = oldSocket.SocketType;
ProtocolType = oldSocket.ProtocolType;
LocalEndPoint = oldSocket.LocalEndPoint;
RemoteEndPoint = oldSocket.RemoteEndPoint;
_proxyEndpoint = oldSocket._proxyEndpoint;
Socket = oldSocket.Socket;
RefCount = 1;
}
#region Proxy methods
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureVersionIsValid(byte version)
{
if (version != ProxyConsts.Version)
{
throw new InvalidDataException($"Invalid proxy version: {version}");
}
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureSuccessReply(ReplyField replyField)
{
if (replyField != ReplyField.Succeeded)
{
throw new ProxyException(replyField);
}
}
private TResp SendAndReceive<TReq, TResp>(TReq request)
where TReq : unmanaged
where TResp : unmanaged
{
byte[] requestData = new byte[Marshal.SizeOf<TReq>()];
byte[] responseData = new byte[Marshal.SizeOf<TResp>() + (_proxyAuth?.WrapperLength ?? 0)];
MemoryMarshal.Write(requestData, request);
int expectedSentBytes;
int sentBytes;
if (_proxyAuth != null)
{
expectedSentBytes = requestData.Length + _proxyAuth.WrapperLength;
sentBytes = Socket.Send(_proxyAuth.Wrap(requestData));
}
else
{
expectedSentBytes = requestData.Length;
sentBytes = Socket.Send(requestData);
}
if (sentBytes < expectedSentBytes)
{
throw new InvalidOperationException($"Failed to send the full proxy request: {sentBytes} of {expectedSentBytes} bytes");
}
int expectedReceivedBytes = responseData.Length;
int receivedBytes = Socket.Receive(responseData);
if (receivedBytes < expectedReceivedBytes)
{
throw new InvalidOperationException($"Proxy response size is invalid. Expected {expectedReceivedBytes} bytes, got {receivedBytes}.");
}
if (_proxyAuth != null)
{
return MemoryMarshal.Read<TResp>(_proxyAuth.Unwrap(responseData));
}
else
{
return MemoryMarshal.Read<TResp>(responseData);
}
}
/// <summary>
/// Get the authentication method chosen by the server.
/// </summary>
private AuthMethod GetAuthenticationMethod()
{
var response = SendAndReceive<MethodSelectionRequest1, MethodSelectionResponse>(new MethodSelectionRequest1
{
Version = ProxyConsts.Version,
NumOfMethods = 1,
Methods = new Array1<AuthMethod> { [0] = AuthMethod.NoAuthenticationRequired },
});
EnsureVersionIsValid(response.Version);
return response.Method;
}
/// <summary>
/// Authenticate to the server using a method-specific sub-negotiation.
/// </summary>
/// <param name="method">The authentication method to use.</param>
/// <exception cref="NotImplementedException">The provided authentication method is not implemented.</exception>
/// <exception cref="AuthenticationException">Authentication failed.</exception>
/// <exception cref="ArgumentOutOfRangeException">The provided authentication method is invalid.</exception>
private void Authenticate(AuthMethod method)
{
switch (method)
{
case AuthMethod.NoAuthenticationRequired:
case AuthMethod.GSSAPI:
case AuthMethod.UsernameAndPassword:
_proxyAuth = method.GetAuth();
_proxyAuth.Authenticate();
return;
case AuthMethod.NoAcceptableMethods:
throw new AuthenticationException("No acceptable authentication method found.");
default:
throw new ArgumentOutOfRangeException(nameof(method), method, null);
}
}
/// <summary>
/// Connect to a remote endpoint.
/// </summary>
/// <remarks>
/// In the response from the proxy server
/// <see cref="SocksIpv4Response.BoundAddress"/> maps to the associated IP address,
/// while <see cref="SocksIpv4Response.BoundPort"/> maps to the port assigned to connect to the target host.
/// </remarks>
/// <param name="endpoint">The endpoint to connect to.</param>
/// <returns>The endpoint the server assigned to connect to the target host.</returns>
/// <exception cref="ProxyException">The connection to the specified endpoint failed.</exception>
private IPEndPoint ProxyConnect(IPEndPoint endpoint)
{
Socket.Connect(_proxyEndpoint);
Authenticate(GetAuthenticationMethod());
var response = SendAndReceive<SocksIpv4Request, SocksIpv4Response>(new SocksIpv4Request
{
Version = ProxyConsts.Version,
Command = ProxyCommand.Connect,
Reserved = 0x00,
AddressType = AddressType.Ipv4Address,
DestinationAddress = endpoint.Address,
DestinationPort = (ushort)endpoint.Port,
});
EnsureVersionIsValid(response.Version);
EnsureSuccessReply(response.ReplyField);
_ready = true;
return new IPEndPoint(response.BoundAddress, response.BoundPort);
}
/// <summary>
/// Listen for an incoming connection from the specified endpoint.
/// The specified endpoint may be 0 if it's not known beforehand.
/// </summary>
/// <remarks>
/// The specified endpoint is only used to restrict
/// which clients are allowed to connect to the endpoint associated to this request.
/// </remarks>
/// <param name="endpoint">The endpoint of the incoming connection.</param>
/// <returns>The endpoint the server uses to listen for an incoming connection.</returns>
private IPEndPoint ProxyBind(IPEndPoint endpoint)
{
Socket.Connect(_proxyEndpoint);
Authenticate(GetAuthenticationMethod());
var response = SendAndReceive<SocksIpv4Request, SocksIpv4Response>(new SocksIpv4Request
{
Version = ProxyConsts.Version,
Command = ProxyCommand.Bind,
Reserved = 0x00,
AddressType = AddressType.Ipv4Address,
DestinationAddress = endpoint.Address,
DestinationPort = (ushort)endpoint.Port,
});
EnsureVersionIsValid(response.Version);
EnsureSuccessReply(response.ReplyField);
return new IPEndPoint(response.BoundAddress, response.BoundPort);
}
/// <summary>
/// Get the anticipated incoming connection.
/// </summary>
/// <returns>The endpoint of the incoming connection.</returns>
/// <exception cref="InvalidOperationException">The response length is too small.</exception>
private IPEndPoint WaitForIncomingConnection()
{
byte[] responseData = new byte[Marshal.SizeOf<SocksIpv4Response>() + _proxyAuth.WrapperLength];
int expectedReceivedBytes = responseData.Length;
int receivedBytes = Socket.Receive(responseData);
if (receivedBytes < expectedReceivedBytes)
{
throw new InvalidOperationException($"Proxy response size is invalid. Expected {expectedReceivedBytes} bytes, got {receivedBytes}.");
}
var response = MemoryMarshal.Read<SocksIpv4Response>(_proxyAuth.Unwrap(responseData));
EnsureVersionIsValid(response.Version);
EnsureSuccessReply(response.ReplyField);
_ready = true;
return new IPEndPoint(response.BoundAddress, response.BoundPort);
}
/// <summary>
/// Create a UDP relay.
/// The specified endpoint may be 0 if it's not known beforehand.
/// </summary>
/// <remarks>
/// The specified endpoint is only used to restrict which clients are allowed to use the relay.
/// </remarks>
/// <param name="endpoint">The endpoint used to send UDP datagrams to the relay.</param>
private void AssociateUdp(IPEndPoint endpoint)
{
Socket.Connect(_proxyEndpoint);
Authenticate(GetAuthenticationMethod());
var response = SendAndReceive<SocksIpv4Request, SocksIpv4Response>(new SocksIpv4Request
{
Version = ProxyConsts.Version,
Command = ProxyCommand.UdpAssociate,
Reserved = 0x00,
AddressType = AddressType.Ipv4Address,
DestinationAddress = endpoint.Address,
DestinationPort = (ushort)endpoint.Port,
});
EnsureVersionIsValid(response.Version);
EnsureSuccessReply(response.ReplyField);
_udpEndpoint = new IPEndPoint(response.BoundAddress, response.BoundPort);
_udpSocket = new Socket(AddressFamily, SocketType, ProtocolType)
{
Blocking = Socket.Blocking,
};
_udpSocket.Bind(endpoint);
_ready = true;
}
#endregion
public LinuxError Send(out int sendSize, ReadOnlySpan<byte> buffer, BsdSocketFlags flags)
{
if (!_ready)
{
throw new InvalidOperationException("No connection has been established. Issue a proxy command before sending data.");
}
if (_udpEndpoint != null)
{
throw new InvalidOperationException($"UDP packets can only be sent using {nameof(SendTo)}.");
}
try
{
sendSize = Socket.Send(_proxyAuth.Wrap(buffer), ManagedSocket.ConvertBsdSocketFlags(flags)) - _proxyAuth.WrapperLength;
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
sendSize = -1;
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint)
{
if (!_ready || _udpEndpoint == null)
{
throw new InvalidOperationException("No connection has been established. Issue a proxy command before sending data.");
}
byte[] data = new byte[Marshal.SizeOf<SocksIpv4UdpHeader>() + buffer.Length];
var header = new SocksIpv4UdpHeader
{
Reserved = 0,
Fragment = 0,
AddressType = AddressType.Ipv4Address,
DestinationAddress = remoteEndPoint.Address,
DestinationPort = (ushort)remoteEndPoint.Port,
};
MemoryMarshal.Write(data, header);
buffer[..size].CopyTo(data.AsSpan()[Marshal.SizeOf<SocksIpv4UdpHeader>()..]);
try
{
sendSize = _udpSocket.SendTo(_proxyAuth.Wrap(data), _udpEndpoint) -
Marshal.SizeOf<SocksIpv4UdpHeader>() - _proxyAuth.WrapperLength;
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
sendSize = -1;
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public LinuxError Receive(out int receiveSize, Span<byte> buffer, BsdSocketFlags flags)
{
LinuxError result;
bool shouldBlockAfterOperation = false;
if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
{
Blocking = false;
shouldBlockAfterOperation = true;
}
byte[] data = new byte[buffer.Length + _proxyAuth.WrapperLength];
try
{
receiveSize = Socket.Receive(data) - _proxyAuth.WrapperLength;
_proxyAuth.Unwrap(data).CopyTo(buffer);
result = LinuxError.SUCCESS;
}
catch (SocketException exception)
{
receiveSize = -1;
result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
if (shouldBlockAfterOperation)
{
Blocking = true;
}
return result;
}
public LinuxError ReceiveFrom(out int receiveSize, Span<byte> buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint)
{
LinuxError result;
remoteEndPoint = new IPEndPoint(IPAddress.Any, 0);
bool shouldBlockAfterOperation = false;
byte[] data = new byte[size + _proxyAuth.WrapperLength + Marshal.SizeOf<SocksIpv4UdpHeader>()];
EndPoint udpEndpoint = _udpEndpoint;
if (_udpSocket is not { IsBound: true })
{
receiveSize = -1;
return LinuxError.EOPNOTSUPP;
}
if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait))
{
Blocking = false;
shouldBlockAfterOperation = true;
}
try
{
receiveSize = _udpSocket.ReceiveFrom(data, ref udpEndpoint) - _proxyAuth.WrapperLength - Marshal.SizeOf<SocksIpv4UdpHeader>();
data = _proxyAuth.Unwrap(data).ToArray();
var header = MemoryMarshal.Read<SocksIpv4UdpHeader>(data);
// An implementation that doesn't support fragmentation must drop any fragmented datagram
// TODO: Implement support for fragmentation
if (header.Fragment != 0)
{
if (shouldBlockAfterOperation)
{
Blocking = true;
}
receiveSize = -1;
return LinuxError.EOPNOTSUPP;
}
remoteEndPoint = new IPEndPoint(header.DestinationAddress, header.DestinationPort);
data.AsSpan()[Marshal.SizeOf<SocksIpv4UdpHeader>()..].CopyTo(buffer[..size]);
result = LinuxError.SUCCESS;
}
catch (SocketException exception)
{
receiveSize = -1;
result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
if (shouldBlockAfterOperation)
{
Blocking = true;
}
return result;
}
public LinuxError Bind(IPEndPoint localEndPoint)
{
switch (ProtocolType)
{
case ProtocolType.Tcp:
try
{
Socket.Bind(localEndPoint);
LocalEndPoint = ProxyBind(_endpointZero);
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
case ProtocolType.Udp:
try
{
AssociateUdp(localEndPoint);
LocalEndPoint = localEndPoint;
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
default:
return LinuxError.EOPNOTSUPP;
}
}
public LinuxError Listen(int backlog)
{
if (backlog > 1)
{
return LinuxError.EOPNOTSUPP;
}
return LinuxError.SUCCESS;
}
public LinuxError Accept(out ISocket newSocket)
{
try
{
RemoteEndPoint = WaitForIncomingConnection();
newSocket = new ManagedProxySocket(this);
LocalEndPoint = null;
RemoteEndPoint = null;
_ready = false;
Socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
newSocket = null;
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public LinuxError Connect(IPEndPoint remoteEndPoint)
{
try
{
LocalEndPoint = ProxyConnect(remoteEndPoint);
RemoteEndPoint = remoteEndPoint;
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
if (!Blocking && exception.ErrorCode == (int)WsaError.WSAEWOULDBLOCK)
{
return LinuxError.EINPROGRESS;
}
else
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
}
public bool Poll(int microSeconds, SelectMode mode)
{
if (_udpEndpoint != null)
{
return _udpSocket.Poll(microSeconds, mode);
}
else
{
return Socket.Poll(microSeconds, mode);
}
}
public LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue)
{
try
{
LinuxError result = WinSockHelper.ValidateSocketOption(option, level, write: false);
if (result != LinuxError.SUCCESS)
{
Logger.Warning?.Print(LogClass.ServiceBsd, $"Invalid GetSockOpt Option: {option} Level: {level}");
return result;
}
if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
{
Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}");
optionValue.Clear();
return LinuxError.SUCCESS;
}
byte[] tempOptionValue = new byte[optionValue.Length];
if (_udpEndpoint != null)
{
_udpSocket.GetSocketOption(level, optionName, tempOptionValue);
}
else
{
Socket.GetSocketOption(level, optionName, tempOptionValue);
}
tempOptionValue.AsSpan().CopyTo(optionValue);
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue)
{
try
{
LinuxError result = WinSockHelper.ValidateSocketOption(option, level, write: true);
if (result != LinuxError.SUCCESS)
{
Logger.Warning?.Print(LogClass.ServiceBsd, $"Invalid SetSockOpt Option: {option} Level: {level}");
return result;
}
if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName))
{
Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}");
return LinuxError.SUCCESS;
}
int value = optionValue.Length >= 4 ? MemoryMarshal.Read<int>(optionValue) : MemoryMarshal.Read<byte>(optionValue);
if (level == SocketOptionLevel.Socket && option == BsdSocketOption.SoLinger)
{
int value2 = 0;
if (optionValue.Length >= 8)
{
value2 = MemoryMarshal.Read<int>(optionValue[4..]);
}
if (_udpEndpoint != null)
{
_udpSocket.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2));
}
else
{
Socket.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2));
}
}
else
{
if (_udpEndpoint != null)
{
_udpSocket.SetSocketOption(level, optionName, value);
}
else
{
Socket.SetSocketOption(level, optionName, value);
}
}
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public LinuxError Read(out int readSize, Span<byte> buffer)
{
return Receive(out readSize, buffer, BsdSocketFlags.None);
}
public LinuxError Write(out int writeSize, ReadOnlySpan<byte> buffer)
{
return Send(out writeSize, buffer, BsdSocketFlags.None);
}
public LinuxError Shutdown(BsdSocketShutdownFlags how)
{
try
{
_udpSocket?.Shutdown((SocketShutdown)how);
Socket.Shutdown((SocketShutdown)how);
return LinuxError.SUCCESS;
}
catch (SocketException exception)
{
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
}
}
public void Disconnect()
{
Socket.Disconnect(true);
_udpEndpoint = null;
RemoteEndPoint = null;
LocalEndPoint = _endpointZero;
_ready = false;
}
public void Close()
{
_udpSocket?.Close();
Socket.Close();
}
public void Dispose()
{
_udpSocket?.Close();
_udpSocket?.Dispose();
Socket.Close();
Socket.Dispose();
}
public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout)
{
throw new NotImplementedException();
}
public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags)
{
throw new NotImplementedException();
}
}
}

View File

@ -11,7 +11,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
{
class ManagedSocket : ISocket
{
public int Refcount { get; set; }
public int RefCount { get; set; }
public AddressFamily AddressFamily => Socket.AddressFamily;
@ -32,16 +32,16 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl
public ManagedSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
Socket = new Socket(addressFamily, socketType, protocolType);
Refcount = 1;
RefCount = 1;
}
private ManagedSocket(Socket socket)
{
Socket = socket;
Refcount = 1;
RefCount = 1;
}
private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags)
internal static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags)
{
SocketFlags socketFlags = SocketFlags.None;

View File

@ -0,0 +1,10 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
{
public enum AddressType : byte
{
Ipv4Address = 0x01,
// TODO: Implement support for DomainName and IPv6 addresses to be SOCKS5 compliant
DomainName = 0x03,
Ipv6Address,
}
}

View File

@ -0,0 +1,32 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth
{
public enum AuthMethod : byte
{
NoAuthenticationRequired,
GSSAPI,
UsernameAndPassword,
// 0x03 - 0x7F: IANA assigned
// 0x80 - 0xFE: Reserved for private methods
NoAcceptableMethods = 0xFF,
}
public static class AuthMethodExtensions
{
public static IProxyAuth GetAuth(this AuthMethod authMethod)
{
return authMethod switch
{
AuthMethod.NoAuthenticationRequired => new NoAuthentication(),
// TODO: Implement GSSAPI to be SOCKS5 compliant
AuthMethod.GSSAPI => throw new NotImplementedException(
$"Authentication method not implemented: {authMethod}"),
AuthMethod.UsernameAndPassword => throw new NotImplementedException(
$"Authentication method not implemented: {authMethod}"),
_ => throw new ArgumentException($"Invalid authentication method provided: {authMethod}",
nameof(authMethod)),
};
}
}
}

View File

@ -0,0 +1,33 @@
using System;
using System.Security.Authentication;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth
{
public interface IProxyAuth
{
/// <summary>
/// The amount of additional bytes required for a wrapped packet;
/// </summary>
public int WrapperLength { get; }
/// <summary>
/// Authenticate to the server using a method-specific sub-negotiation.
/// </summary>
/// <exception cref="AuthenticationException">Authentication failed.</exception>
public void Authenticate();
/// <summary>
/// Wrap the packet as required by the negotiated authentication method.
/// </summary>
/// <param name="packet">The packet to wrap.</param>
/// <returns>The wrapped packet.</returns>
public ReadOnlySpan<byte> Wrap(ReadOnlySpan<byte> packet);
/// <summary>
/// Unwrap the packet and perform the checks as required by the negotiated authentication method.
/// </summary>
/// <param name="packet">The packet to unwrap.</param>
/// <returns>The unwrapped packet.</returns>
public ReadOnlySpan<byte> Unwrap(ReadOnlySpan<byte> packet);
}
}

View File

@ -0,0 +1,24 @@
using System;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth
{
public class NoAuthentication : IProxyAuth
{
public int WrapperLength => 0;
public void Authenticate()
{
// Nothing to do here.
}
public ReadOnlySpan<byte> Wrap(ReadOnlySpan<byte> packet)
{
return packet;
}
public ReadOnlySpan<byte> Unwrap(ReadOnlySpan<byte> packet)
{
return packet;
}
}
}

View File

@ -0,0 +1,12 @@
using Ryujinx.Common.Memory;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
public struct MethodSelectionRequest1
{
public byte Version;
public byte NumOfMethods;
public Array1<AuthMethod> Methods;
}
}

View File

@ -0,0 +1,10 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
public struct MethodSelectionResponse
{
public byte Version;
public AuthMethod Method;
}
}

View File

@ -0,0 +1,40 @@
using System;
using System.Net;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct SocksIpv4Request
{
public byte Version;
public ProxyCommand Command;
public byte Reserved;
// NOTE: Must be AddressType.Ipv4Address
public AddressType AddressType;
private uint _destinationAddress;
private ushort _destinationPort;
public IPAddress DestinationAddress
{
readonly get => new(BitConverter.GetBytes(_destinationAddress));
set => _destinationAddress = BitConverter.ToUInt32(value.GetAddressBytes());
}
public ushort DestinationPort
{
readonly get
{
byte[] portBytes = BitConverter.GetBytes(_destinationPort);
Array.Reverse(portBytes);
return BitConverter.ToUInt16(portBytes);
}
set
{
byte[] portBytes = BitConverter.GetBytes(value);
Array.Reverse(portBytes);
_destinationPort = BitConverter.ToUInt16(portBytes);
}
}
}
}

View File

@ -0,0 +1,40 @@
using System;
using System.Net;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct SocksIpv4Response
{
public byte Version;
public ReplyField ReplyField;
public byte Reserved;
// NOTE: Must be AddressType.Ipv4Address
public AddressType AddressType;
private uint _boundAddress;
private ushort _boundPort;
public IPAddress BoundAddress
{
readonly get => new(BitConverter.GetBytes(_boundAddress));
set => _boundAddress = BitConverter.ToUInt32(value.GetAddressBytes());
}
public ushort BoundPort
{
readonly get
{
byte[] portBytes = BitConverter.GetBytes(_boundPort);
Array.Reverse(portBytes);
return BitConverter.ToUInt16(portBytes);
}
set
{
byte[] portBytes = BitConverter.GetBytes(value);
Array.Reverse(portBytes);
_boundPort = BitConverter.ToUInt16(portBytes);
}
}
}
}

View File

@ -0,0 +1,39 @@
using System;
using System.Net;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct SocksIpv4UdpHeader
{
public ushort Reserved;
public byte Fragment;
// NOTE: Must be AddressType.Ipv4Address
public AddressType AddressType;
private uint _destinationAddress;
private ushort _destinationPort;
public IPAddress DestinationAddress
{
readonly get => new(BitConverter.GetBytes(_destinationAddress));
set => _destinationAddress = BitConverter.ToUInt32(value.GetAddressBytes());
}
public ushort DestinationPort
{
readonly get
{
byte[] portBytes = BitConverter.GetBytes(_destinationPort);
Array.Reverse(portBytes);
return BitConverter.ToUInt16(portBytes);
}
set
{
byte[] portBytes = BitConverter.GetBytes(value);
Array.Reverse(portBytes);
_destinationPort = BitConverter.ToUInt16(portBytes);
}
}
}
}

View File

@ -0,0 +1,9 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
{
public enum ProxyCommand : byte
{
Connect = 0x01,
Bind,
UdpAssociate,
}
}

View File

@ -0,0 +1,7 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
{
public static class ProxyConsts
{
public const byte Version = 0x05;
}
}

View File

@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
{
public class ProxyException : Exception
{
private static readonly Dictionary<ReplyField, string> _exceptionMessages = new()
{
{ ReplyField.ServerFailure, "The proxy server failed to process this request." },
{ ReplyField.ConnectionNotAllowed, "The proxy server did not allow the connection." },
{ ReplyField.NetworkUnreachable, "The target network is unreachable." },
{ ReplyField.HostUnreachable, "The target is unreachable." },
{ ReplyField.ConnectionRefused, "The target refused the connection" },
{ ReplyField.TTLExpired, "The TTL expired before reaching the target." },
{ ReplyField.CommandNotSupported, "The specified command is not supported." },
{ ReplyField.AddressTypeNotSupported, "The specified address type is not supported." },
};
public ReplyField ReplyCode { get; }
public ProxyException(ReplyField replyField) : base($"{_exceptionMessages[replyField]} ({replyField})")
{
ReplyCode = replyField;
}
public ProxyException(ReplyField replyField, string message) : base(message)
{
ReplyCode = replyField;
}
public ProxyException(ReplyField replyField, string message, Exception innerException) : base(message,
innerException)
{
ReplyCode = replyField;
}
}
}

View File

@ -0,0 +1,52 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
{
public static class ProxyManager
{
private static readonly Dictionary<string, EndPoint> _proxyEndpoints = new();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static string GetKey(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
return string.Join("-", new[] { (int)addressFamily, (int)socketType, (int)protocolType });
}
internal static ISocket GetSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
if (_proxyEndpoints.TryGetValue(GetKey(addressFamily, socketType, protocolType), out EndPoint endPoint))
{
return new ManagedProxySocket(addressFamily, socketType, protocolType, endPoint);
}
return new ManagedSocket(addressFamily, socketType, protocolType);
}
public static void AddOrUpdate(EndPoint endPoint,
AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
_proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = endPoint;
}
public static void AddOrUpdate(IPAddress address, int port,
AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
_proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = new IPEndPoint(address, port);
}
public static void AddOrUpdate(string host, int port,
AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
_proxyEndpoints[GetKey(addressFamily, socketType, protocolType)] = new DnsEndPoint(host, port);
}
public static bool Remove(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
return _proxyEndpoints.Remove(GetKey(addressFamily, socketType, protocolType));
}
}
}

View File

@ -0,0 +1,15 @@
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy
{
public enum ReplyField : byte
{
Succeeded,
ServerFailure,
ConnectionNotAllowed,
NetworkUnreachable,
HostUnreachable,
ConnectionRefused,
TTLExpired,
CommandNotSupported,
AddressTypeNotSupported,
}
}

View File

@ -1,4 +1,4 @@
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
@ -116,7 +116,22 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
public ResultCode Handshake(string hostName)
{
StartSslOperation();
_stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
NetworkStream networkStream;
if (Socket is ManagedProxySocket proxySocket)
{
networkStream = new NetworkStream(proxySocket.Socket, false);
}
else if (Socket is ManagedSocket managedSocket)
{
networkStream = new NetworkStream(managedSocket.Socket, false);
}
else
{
throw new NotSupportedException($"Socket of type {Socket.GetType()} does not support SSL.");
}
_stream = new SslStream(networkStream, false, null, null);
hostName = RetrieveHostName(hostName);
_stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
EndSslOperation();

View File

@ -0,0 +1,68 @@
using NUnit.Framework;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Types;
using Ryujinx.Tests.HLE.HOS.Services.Sockets.Bsd.Proxy;
using System;
using System.Net;
using System.Net.Sockets;
namespace Ryujinx.Tests.HLE.HOS.Services.Sockets.Bsd.Impl
{
[TestFixture(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)]
[TestFixture(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)]
internal class ManagedProxySocketTestFixture(AddressFamily addressFamily, SocketType socketType,
ProtocolType protocolType)
{
private readonly IPEndPoint _serverEndPoint = new(IPAddress.Loopback, 0);
private readonly MockIpv4Socks5Server _server = new();
private ManagedProxySocket _proxySocket;
[OneTimeSetUp]
public void OneTimeSetUp()
{
_server.Start();
_serverEndPoint.Port = ((IPEndPoint)_server.Endpoint).Port;
}
[OneTimeTearDown]
public void OneTimeTearDown()
{
_server.DisconnectAll();
_server.Stop();
_server.Dispose();
}
[SetUp]
public void SetUp()
{
_proxySocket = new ManagedProxySocket(addressFamily, socketType, protocolType, _serverEndPoint);
}
[TearDown]
public void TearDown()
{
_proxySocket.Dispose();
_proxySocket = null;
}
[Test]
public void Bind()
{
LinuxError result = _proxySocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
bool dequeueResult = _server.MockSessions.TryDequeue(out Guid clientId);
Assert.AreEqual(LinuxError.SUCCESS, result);
Assert.True(dequeueResult);
MockIpv4Socks5NoAuthSession proxySession = (MockIpv4Socks5NoAuthSession)_server.FindSession(clientId);
Assert.AreEqual(0x05, proxySession.UsesVersion);
Assert.True(proxySession.IsAuthenticated);
Assert.True(proxySession.IsLastRequestValid, proxySession.RequestError);
Assert.AreNotEqual(0, proxySession.Command);
Assert.AreNotEqual(ProxyCommand.Connect, proxySession.Command);
Assert.NotNull(_proxySocket.LocalEndPoint);
}
}
}

View File

@ -0,0 +1,28 @@
using NUnit.Framework;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth;
using System;
namespace Ryujinx.Tests.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth
{
public class AuthMethodTests
{
[Test]
public void GetAuth_ReturnValue([Values] AuthMethod authMethod)
{
// TODO: Remove this as soon as we have an implementation for these
if (authMethod is AuthMethod.UsernameAndPassword or AuthMethod.GSSAPI)
{
Assert.Throws<NotImplementedException>(() => authMethod.GetAuth());
return;
}
if (authMethod is AuthMethod.NoAcceptableMethods)
{
Assert.Throws<ArgumentException>(() => authMethod.GetAuth());
return;
}
Assert.IsInstanceOf<IProxyAuth>(authMethod.GetAuth());
}
}
}

View File

@ -0,0 +1,203 @@
using NetCoreServer;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
namespace Ryujinx.Tests.HLE.HOS.Services.Sockets.Bsd.Proxy
{
internal class MockIpv4Socks5Server : TcpServer
{
public readonly ConcurrentQueue<Guid> MockSessions = new();
public MockIpv4Socks5Server() : base(IPAddress.Loopback, 0) { }
protected override TcpSession CreateSession()
{
var session = new MockIpv4Socks5NoAuthSession(this);
MockSessions.Enqueue(session.Id);
return session;
}
}
internal class MockIpv4Socks5NoAuthSession : TcpSession
{
public List<byte[]> Requests = new();
public byte[] Response;
public bool IsLastRequestValid;
public string RequestError;
public byte UsesVersion;
public bool IsAuthenticated;
public AuthMethod[] OfferedMethods;
public ProxyCommand Command;
public AddressType AddressType;
public IPAddress DestinationAddress;
public ushort DestinationPort;
public MockIpv4Socks5NoAuthSession(TcpServer server) : base(server) { }
protected override void OnReceived(byte[] buffer, long offset, long size)
{
Requests.Add(buffer);
if (size < 3)
{
IsLastRequestValid = false;
RequestError = $"Packet is too small. ({size} bytes)";
return;
}
UsesVersion = buffer[0];
if (!IsAuthenticated)
{
Authenticate(buffer, offset, size);
}
else
{
if (Command == 0)
{
ParseCommand(buffer, offset, size);
return;
}
if (Command != ProxyCommand.UdpAssociate && Response is { Length: > 0 })
{
IsLastRequestValid = true;
RequestError = string.Empty;
Send(Response);
Response = null;
}
else if (Command == ProxyCommand.UdpAssociate)
{
}
}
}
public void Reset()
{
Requests.Clear();
IsLastRequestValid = false;
RequestError = string.Empty;
UsesVersion = 0;
IsAuthenticated = false;
OfferedMethods = null;
Command = 0;
AddressType = 0;
DestinationAddress = null;
DestinationPort = 0;
}
private void SendReply(ReplyField replyCode, IPEndPoint boundEndpoint = null)
{
byte[] replyData =
{
// Version
0x05,
// Reply field
(byte)replyCode,
// Reserved
0x00,
// Address type: IPv4
0x01,
// Bound address
0x00, 0x00, 0x00, 0x00,
// Bound port
0x00, 0x00,
};
if (boundEndpoint != null)
{
boundEndpoint.Address.GetAddressBytes().CopyTo(replyData, 4);
BitConverter.GetBytes(boundEndpoint.Port).Reverse().ToArray().CopyTo(replyData, 8);
}
Send(replyData);
}
private void Authenticate(byte[] buffer, long offset, long size)
{
if (size > 2 + buffer[1])
{
IsLastRequestValid = false;
RequestError = $"Packet is too large. (Expected {2 + buffer[1]} bytes, got {size}.)";
return;
}
OfferedMethods = new AuthMethod[buffer[1]];
for (int i = 0; i < OfferedMethods.Length; i++)
{
OfferedMethods[i] = (AuthMethod)buffer[2 + i];
}
if (UsesVersion == 5 && OfferedMethods.Contains(AuthMethod.NoAuthenticationRequired))
{
IsLastRequestValid = true;
RequestError = string.Empty;
Send(new byte[]
{
// Version
0x05,
// Auth method
0x00,
});
IsAuthenticated = true;
}
else
{
IsLastRequestValid = false;
RequestError = $"Couldn't find {AuthMethod.NoAuthenticationRequired} in offered auth methods.";
Send(new byte[]
{
// Version
0x05,
// Auth method: No acceptable method
0xFF,
});
}
}
private void ParseCommand(byte[] buffer, long offset, long size)
{
if (size != 10)
{
IsLastRequestValid = false;
RequestError = $"Packet size is invalid. (Expected 10 bytes, got {size}.)";
SendReply(ReplyField.ServerFailure);
}
Command = (ProxyCommand)buffer[1];
if (buffer[2] != 0x00)
{
IsLastRequestValid = false;
RequestError = $"Reserved must be 0x00. (actual value: 0x{buffer[2]:x})";
SendReply(ReplyField.ServerFailure);
return;
}
if (buffer[3] != 0x01)
{
IsLastRequestValid = false;
RequestError = $"AddressType must be 0x01. (actual value: 0x{buffer[3]:x})";
SendReply(ReplyField.AddressTypeNotSupported);
return;
}
DestinationAddress = new IPAddress(buffer[4..8]);
DestinationPort = BitConverter.ToUInt16(buffer, 8);
IsLastRequestValid = true;
RequestError = string.Empty;
SendReply(ReplyField.Succeeded);
}
}
}

View File

@ -0,0 +1,46 @@
using NUnit.Framework;
using Ryujinx.Common.Memory;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Auth;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets;
using System.Runtime.InteropServices;
namespace Ryujinx.Tests.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
public class MethodSelectionTests
{
[Test]
public void MethodSelectionRequest1_Size()
{
// Version: 1 byte
// Number of methods: 1 byte
// Methods: 1 - 255 bytes (in this case: 1)
// Total: 3 bytes
var request = new MethodSelectionRequest1
{
Version = ProxyConsts.Version,
NumOfMethods = 1,
Methods = new Array1<AuthMethod> { [0] = AuthMethod.NoAuthenticationRequired },
};
Assert.AreEqual(3, Marshal.SizeOf(request));
}
[Test]
public void MethodSelectionResponse_Size()
{
// Version: 1 byte
// Method: 1 byte
// Total: 2 bytes
var response = new MethodSelectionResponse()
{
Version = ProxyConsts.Version,
Method = AuthMethod.NoAuthenticationRequired,
};
Assert.AreEqual(2, Marshal.SizeOf(response));
}
}
}

View File

@ -0,0 +1,129 @@
using NUnit.Framework;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets;
using System;
using System.Net;
using System.Runtime.InteropServices;
namespace Ryujinx.Tests.HLE.HOS.Services.Sockets.Bsd.Proxy.Packets
{
public class SocksIpv4Tests
{
[Test]
public void Request_Size()
{
// Version: 1 byte
// Command: 1 byte
// Reserved: 1 byte
// Address type: 1 byte
// IPv4 address: 4 bytes
// Port: 2 bytes
// Total: 10 bytes
var request = new SocksIpv4Request
{
Version = ProxyConsts.Version,
Reserved = 0x00,
Command = ProxyCommand.Connect,
AddressType = AddressType.Ipv4Address,
DestinationAddress = IPAddress.Any,
DestinationPort = 0,
};
Assert.AreEqual(10, Marshal.SizeOf(request));
}
[Test]
public void Response_Size()
{
// Version: 1 byte
// Reply: 1 byte
// Reserved: 1 byte
// Address type: 1 byte
// IPv4 address: 4 bytes
// Port: 2 bytes
// Total: 10 bytes
var response = new SocksIpv4Response
{
Version = ProxyConsts.Version,
Reserved = 0x00,
ReplyField = ReplyField.Succeeded,
AddressType = AddressType.Ipv4Address,
BoundAddress = IPAddress.Any,
BoundPort = 0,
};
Assert.AreEqual(10, Marshal.SizeOf(response));
}
[Test]
public void UdpHeader_Size()
{
// Reserved: 2 bytes
// Fragment: 1 byte
// Address type: 1 byte
// IPv4 address: 4 bytes
// Port: 2 bytes
// Total: 10 bytes
var header = new SocksIpv4UdpHeader
{
Reserved = 0x0000,
Fragment = 0,
AddressType = AddressType.Ipv4Address,
DestinationAddress = IPAddress.Any,
DestinationPort = 0,
};
Assert.AreEqual(10, Marshal.SizeOf(header));
}
[Test, Sequential]
public void Port_ByteOrder(
[Values((ushort)443, (ushort)2127, (ushort)22)] ushort port,
[Values((ushort)47873, (ushort)20232, (ushort)5632)] ushort expected)
{
var request = new SocksIpv4Request
{
Version = ProxyConsts.Version,
Reserved = 0x00,
Command = ProxyCommand.Connect,
AddressType = AddressType.Ipv4Address,
DestinationAddress = IPAddress.Any,
DestinationPort = port,
};
var response = new SocksIpv4Response
{
Version = ProxyConsts.Version,
Reserved = 0x00,
ReplyField = ReplyField.Succeeded,
AddressType = AddressType.Ipv4Address,
BoundAddress = IPAddress.Any,
BoundPort = port,
};
var header = new SocksIpv4UdpHeader
{
Reserved = 0x0000,
Fragment = 0,
AddressType = AddressType.Ipv4Address,
DestinationAddress = IPAddress.Any,
DestinationPort = port,
};
byte[] requestData = new byte[10];
byte[] responseData = new byte[10];
byte[] headerData = new byte[10];
MemoryMarshal.Write(requestData, request);
MemoryMarshal.Write(responseData, response);
MemoryMarshal.Write(headerData, header);
Assert.AreEqual(expected, BitConverter.ToUInt16(requestData, 8));
Assert.AreEqual(expected, BitConverter.ToUInt16(responseData, 8));
Assert.AreEqual(expected, BitConverter.ToUInt16(headerData, 8));
}
}
}