记录一下 简单udp和sni 代理 done

由于之前借鉴了 Kestrel 非常多的抽象和优化实现,后续的扩展非常便利,

实现 简单udp和sni 代理 两个功能比预期快了超多(当然也有偷懒因素)

(PS 大家有空的话,能否在 GitHub https://github.com/fs7744/NZOrz 点个 star 呢?毕竟借鉴代码也不易呀 哈哈哈哈哈)

简单udp代理

这里的udp 代理功能比较简单:代理程序收到任何 udp 包都会通过路由匹配找 upstream ,然后转发给upstream

udp proxy 使用配置

基本格式和之前 tcp proxy 一致,

只是 Protocols 得选择 UDP, 另外多了 UdpResponses 配置项,用于指定允许 upstream 返回多少个 udp 包给请求者, 默认为 0,即不返回任何包

{   "Logging": {     "LogLevel": {       "Default": "Information"     }   },   "ReverseProxy": {     "Routes": {       "udpTest": {         "Protocols": [ "UDP" ],         "Match": {           "Hosts": [ "*:5000" ]         },         "ClusterId": "udpTest",         "RetryCount": 1,         "UdpResponses": 1,         "Timeout": "00:00:11"       }     },     "Clusters": {       "udpTest": {         "LoadBalancingPolicy": "RoundRobin",         "HealthCheck": {           "Passive": {             "Enable": true           }         },         "Destinations": [           {             "Address": "127.0.0.1:11000"           }         ]       }     }   } } 

实现

这里列举一下,表明有多简单

ps: 由于要实现的是非常简单udp代理,所以不基于IMultiplexedConnectionListener ,而基于 IConnectionListener 方式 (对,就是俺偷懒了)

1. 实现 UdpConnectionContext

偷懒就直接把udp 包数据放 context 上了,不放 Parameters 上,减少字典实例和内存使用

/// <summary>
/// Connection context that represents a single received UDP datagram.
/// The datagram payload is kept directly on the context (not in a parameters
/// dictionary) to avoid an extra dictionary instance and its memory.
/// </summary>
public sealed class UdpConnectionContext : TransportConnection
{
    // Owner of the buffer holding the datagram; released in DisposeAsync.
    private readonly IMemoryOwner<byte> _buffer;

    /// <summary>
    /// Creates a context from a receive result, taking over its buffer,
    /// received byte count and remote endpoint.
    /// </summary>
    /// <param name="socket">Socket the datagram was received on.</param>
    /// <param name="result">Result of the receive operation.</param>
    public UdpConnectionContext(Socket socket, UdpReceiveFromResult result)
        : this(socket, result.RemoteEndPoint, result.ReceivedBytesCount, result.Buffer)
    {
    }

    /// <summary>
    /// Creates a context from the individual parts of a received datagram.
    /// </summary>
    /// <param name="socket">Socket the datagram was received on.</param>
    /// <param name="remoteEndPoint">Endpoint the datagram came from.</param>
    /// <param name="receivedBytes">Number of valid bytes at the start of <paramref name="memory"/>.</param>
    /// <param name="memory">Owner of the buffer containing the datagram; disposed together with this context.</param>
    public UdpConnectionContext(Socket socket, EndPoint remoteEndPoint, int receivedBytes, IMemoryOwner<byte> memory)
    {
        Socket = socket;
        ReceivedBytesCount = receivedBytes;
        _buffer = memory;
        LocalEndPoint = socket.LocalEndPoint;
        RemoteEndPoint = remoteEndPoint;
    }

    /// <summary>Socket the datagram was received on.</summary>
    public Socket Socket { get; }

    /// <summary>Number of bytes received in the datagram.</summary>
    public int ReceivedBytesCount { get; }

    /// <summary>The received payload: the first <see cref="ReceivedBytesCount"/> bytes of the owned buffer.</summary>
    public Memory<byte> ReceivedBytes => _buffer.Memory[..ReceivedBytesCount];

    /// <summary>Releases the owned datagram buffer.</summary>
    public override ValueTask DisposeAsync()
    {
        _buffer.Dispose();
        return ValueTask.CompletedTask;
    }
}
2. 实现 IConnectionListener
/// <summary>
/// <see cref="IConnectionListener"/> for UDP: binds one datagram socket and surfaces every
/// received datagram as a <see cref="UdpConnectionContext"/> from <see cref="AcceptAsync"/>.
/// </summary>
internal sealed class UdpConnectionListener : IConnectionListener
{
    // Requested endpoint until Bind() succeeds; afterwards the endpoint the socket is actually bound to.
    private EndPoint? udpEndPoint;
    private readonly GatewayProtocols protocols;
    private readonly OrzLogger _logger;
    private readonly IUdpConnectionFactory connectionFactory;
    private readonly Func<EndPoint, GatewayProtocols, Socket> createBoundListenSocket;
    private Socket? _listenSocket;

    /// <summary>
    /// Creates an unbound listener; call <see cref="Bind"/> before accepting.
    /// </summary>
    /// <param name="udpEndPoint">Endpoint to bind to.</param>
    /// <param name="protocols">Protocols passed to the socket factory when binding.</param>
    /// <param name="contractor">Supplies the socket transport options whose <c>CreateBoundListenSocket</c> is used.</param>
    /// <param name="logger">Logger used to report receive errors.</param>
    /// <param name="connectionFactory">Factory that performs the actual datagram receive.</param>
    public UdpConnectionListener(EndPoint? udpEndPoint, GatewayProtocols protocols, IRouteContractor contractor, OrzLogger logger, IUdpConnectionFactory connectionFactory)
    {
        this.udpEndPoint = udpEndPoint;
        this.protocols = protocols;
        _logger = logger;
        this.connectionFactory = connectionFactory;
        createBoundListenSocket = contractor.GetSocketTransportOptions().CreateBoundListenSocket;
    }

    /// <summary>
    /// The endpoint of this listener. After <see cref="Bind"/> this is the socket's local endpoint,
    /// which may differ from the requested one (for example when port 0 was requested).
    /// </summary>
    public EndPoint EndPoint => udpEndPoint;

    /// <summary>
    /// Creates and binds the listen socket.
    /// </summary>
    /// <exception cref="InvalidOperationException">The listener is already bound.</exception>
    /// <exception cref="AddressInUseException">The address is already in use.</exception>
    internal void Bind()
    {
        if (_listenSocket != null)
        {
            throw new InvalidOperationException("Transport is already bound.");
        }

        Socket listenSocket;
        try
        {
            listenSocket = createBoundListenSocket(EndPoint, protocols);
        }
        catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
        {
            throw new AddressInUseException(e.Message, e);
        }

        Debug.Assert(listenSocket.LocalEndPoint != null);

        // Publish the endpoint that was really bound so callers that asked for
        // "any free port" can discover the port that was chosen.
        udpEndPoint = listenSocket.LocalEndPoint;
        _listenSocket = listenSocket;
    }

    /// <summary>
    /// Waits for the next datagram and wraps it in a <see cref="UdpConnectionContext"/>.
    /// </summary>
    /// <param name="cancellationToken">Token passed to the receive operation.</param>
    /// <returns>The context for the received datagram, or <c>null</c> once the listener was unbound/disposed.</returns>
    public async ValueTask<ConnectionContext?> AcceptAsync(CancellationToken cancellationToken = default)
    {
        Debug.Assert(_listenSocket != null, "Bind must be called first.");

        while (true)
        {
            try
            {
                var r = await connectionFactory.ReceiveAsync(_listenSocket, cancellationToken);
                return new UdpConnectionContext(_listenSocket, r);
            }
            catch (ObjectDisposedException)
            {
                // A call was made to UnbindAsync/DisposeAsync just return null which signals we're done
                return null;
            }
            catch (SocketException e) when (e.SocketErrorCode == SocketError.OperationAborted)
            {
                // A call was made to UnbindAsync/DisposeAsync just return null which signals we're done
                return null;
            }
            catch (SocketException)
            {
                // Any other socket error is logged and the receive is retried.
                _logger.ConnectionReset("(null)");
            }
        }
    }

    /// <summary>Disposes the listen socket, if any.</summary>
    public ValueTask DisposeAsync()
    {
        _listenSocket?.Dispose();

        return default;
    }

    /// <summary>Stops listening by disposing the listen socket, if any.</summary>
    public ValueTask UnbindAsync(CancellationToken cancellationToken = default)
    {
        _listenSocket?.Dispose();
        return default;
    }
}
3. 实现 IConnectionListenerFactory
/// <summary>
/// Listener factory for UDP endpoints: creates and binds a <see cref="UdpConnectionListener"/>.
/// </summary>
public sealed class UdpTransportFactory : IConnectionListenerFactory, IConnectionListenerFactorySelector
{
    private readonly IRouteContractor contractor;
    private readonly OrzLogger logger;
    private readonly IUdpConnectionFactory connectionFactory;

    /// <summary>
    /// Creates the factory.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public UdpTransportFactory(
        IRouteContractor contractor,
        OrzLogger logger,
        IUdpConnectionFactory connectionFactory)
    {
        ArgumentNullException.ThrowIfNull(contractor);
        ArgumentNullException.ThrowIfNull(logger);
        // Validate all dependencies the same way, so a missing registration fails here
        // instead of as a NullReferenceException on the first receive.
        ArgumentNullException.ThrowIfNull(connectionFactory);

        this.contractor = contractor;
        this.logger = logger;
        this.connectionFactory = connectionFactory;
    }

    /// <summary>
    /// Creates a UDP listener for <paramref name="endpoint"/> and binds it immediately.
    /// </summary>
    /// <remarks>
    /// The listener is always created with <c>GatewayProtocols.UDP</c>; the
    /// <paramref name="protocols"/> argument is not forwarded.
    /// </remarks>
    public ValueTask<IConnectionListener> BindAsync(EndPoint endpoint, GatewayProtocols protocols, CancellationToken cancellationToken = default)
    {
        var transport = new UdpConnectionListener(endpoint, GatewayProtocols.UDP, contractor, logger, connectionFactory);
        transport.Bind();
        return new ValueTask<IConnectionListener>(transport);
    }

    /// <summary>
    /// Returns <c>true</c> only when the UDP flag is set and the endpoint is an <see cref="IPEndPoint"/>.
    /// </summary>
    public bool CanBind(EndPoint endpoint, GatewayProtocols protocols)
    {
        if (!protocols.HasFlag(GatewayProtocols.UDP))
        {
            return false;
        }

        return endpoint is IPEndPoint;
    }
}
4. 在 L4ProxyMiddleware 实现udp proxy 具体逻辑

路由和之前tcp的公用,这里就不列举了

/// <summary>
/// Layer-4 proxy middleware: dispatches a connection to the SNI, TCP or UDP proxy logic.
/// </summary>
public class L4ProxyMiddleware : IOrderMiddleware
{
    // NOTE: this is an excerpt. The members used below (router, logger, udp, reqUdp, respUdp,
    // cancellationTokenSourcePool, TcpProxyAsync, SNIProxyAsync, DoUdpSendToAsync, ...) are
    // declared elsewhere in the class and are not shown here.

    /// <summary>
    /// Proxies one connection, then always invokes <paramref name="next"/>.
    /// Exceptions thrown by the proxy logic are logged and not rethrown.
    /// </summary>
    public async Task Invoke(ConnectionContext context, ConnectionDelegate next)
    {
        try
        {
            if (context.Protocols == GatewayProtocols.SNI)
            {
                await SNIProxyAsync(context);
            }
            else
            {
                var route = await router.MatchAsync(context);
                if (route is null)
                {
                    logger.NotFoundRouteL4(context.LocalEndPoint);
                }
                else
                {
                    context.Route = route;
                    logger.ProxyBegin(route.RouteId);
                    if (context.Protocols == GatewayProtocols.TCP)
                    {
                        await TcpProxyAsync(context, route);
                    }
                    else
                    {
                        await UdpProxyAsync((UdpConnectionContext)context, route);
                    }
                    logger.ProxyEnd(route.RouteId);
                }
            }
        }
        catch (Exception ex)
        {
            logger.UnexpectedException(ex.Message, ex);
        }
        finally
        {
            await next(context);
        }
    }

    /// <summary>
    /// Forwards the received datagram to an upstream and relays up to
    /// <c>route.UdpResponses</c> response datagrams back to the original sender.
    /// </summary>
    /// <param name="context">Context of the received datagram.</param>
    /// <param name="route">Matched route.</param>
    private async Task UdpProxyAsync(UdpConnectionContext context, RouteConfig route)
    {
        try
        {
            // A dedicated socket is created per datagram; dispose it when the exchange is over
            // so the OS handle is released immediately instead of waiting for finalization.
            // NOTE(review): the socket is IPv4-only (AddressFamily.InterNetwork) — confirm upstreams are never IPv6.
            using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
            // NOTE(review): cts comes from a pool and is never handed back here — confirm whether it must be returned.
            var cts = route.CreateTimeoutTokenSource(cancellationTokenSourcePool);
            var token = cts.Token;
            if (await DoUdpSendToAsync(socket, context, route, route.RetryCount, await reqUdp(context, context.ReceivedBytes, token), token))
            {
                var c = route.UdpResponses;
                while (c > 0)
                {
                    var r = await udp.ReceiveAsync(socket, token);
                    try
                    {
                        c--;
                        await udp.SendToAsync(context.Socket, context.RemoteEndPoint, await respUdp(context, r.GetReceivedBytes(), token), token);
                    }
                    finally
                    {
                        // The receive result owns its buffer (UdpConnectionContext disposes the same
                        // kind of buffer for inbound datagrams); release it once it has been relayed.
                        r.Buffer?.Dispose();
                    }
                }
            }
            else
            {
                logger.NotFoundAvailableUpstream(route.ClusterId);
            }
        }
        catch (OperationCanceledException)
        {
            logger.ConnectUpstreamTimeout(route.RouteId);
        }
        catch (Exception ex)
        {
            logger.UnexpectedException(nameof(UdpProxyAsync), ex);
        }
        finally
        {
            context.SelectedDestination?.ConcurrencyCounter.Decrement();
        }
    }
}

所以是不是真的简单, 理论上基于 Kestrel 也是一个样子哦

优化

当然参考于 Kestrel 的 tcp socket 处理,也是有些简单优化的, 比如

  • 不使用 UdpClient (ps 不是因为实现烂哈,而是其比较公用,没有机会让我们改变里面的内容)
  • 基于 SocketAsyncEventArgs, IValueTaskSource<SocketReceiveFromResult> 和 SocketAsyncEventArgs, IValueTaskSource<int> 实现将异步读写交予 PipeScheduler 的逻辑
  • 基于 ConcurrentQueue<UdpSender> 实现简单的 udp发送对象池,加强对象复用,稍稍稍微减少内存占用
  • 基于 ConcurrentQueue<PooledCancellationTokenSource> 实现简单的 CancellationTokenSource对象池,加强对象复用,稍稍稍微减少内存占用

sni代理

除了 tcp 和 udp 的基本代理, 也尝试实现了一个 对tcp的 sni 代理,(比如 http1 和 http2 的 https)

不过目前只实现了代理不做 ssl 加密解密、由 upstream 自己处理的 pass 模式;如果代理要实现 ssl 加密解密,理论上基于现成的 SslStream 即可实现

sni proxy 使用配置

只需配置Listen 中 公用的 sni 监听端口

然后不同 sni 配置自己的路由和upstream就好

同时每个route 可以通过SupportSslProtocols限制 tls 版本

举个栗子

{   "Logging": {     "LogLevel": {       "Default": "Information"     }   },   "ReverseProxy": {     "Listen": {       "snitest": {         "Protocols": "SNI",         "Address": [ "*:444" ]       }     },     "Routes": {       "snitestroute": {         "Protocols": "SNI",         "SupportSslProtocols": [ "Tls13", "Tls12" ],         "Match": {           "Hosts": [ "*google.com" ]         },         "ClusterId": "apidemo"       },       "snitestroute2": {         "Protocols": "Tcp",         "Match": {           "Hosts": [ "*:448" ]         },         "ClusterId": "apidemo"       }     },     "Clusters": {       "apidemo": {         "LoadBalancingPolicy": "RoundRobin",         "HealthCheck": {           "Active": {             "Enable": true,             "Policy": "Connect"           }         },         "Destinations": [           {             "Address": "https://www.google.com"           }         ]       }     }   } } 

实现

核心实现其实只有 路由 处理 ,proxy 代理和 tcp 代理一模一样(在请求 和 upstream 间搬运 tcp数据而已)

路由处理

通过 ClientHello 找到要访问的 域名, 然后通过域名匹配路由找到 upstream, 最后搬运 tcp数据

ClientHello 解析就直接搬运自TlsFrameHelper

    /// <summary>
    /// Route matching: reads the TLS ClientHello from the connection and matches the
    /// requested server name (and TLS version) against the SNI routes.
    /// </summary>
    /// <param name="context">Incoming connection.</param>
    /// <param name="token">Token used to cancel reading the ClientHello.</param>
    /// <returns>
    /// The matched route (<c>null</c> when nothing matched) together with the
    /// <see cref="ReadResult"/> holding the bytes read so far. Those bytes have NOT been
    /// consumed from the pipe; the caller has to forward them and advance the reader.
    /// </returns>
    public async ValueTask<(RouteConfig, ReadResult)> MatchSNIAsync(ConnectionContext context, CancellationToken token)
    {
        if (sniRoute is null)
        {
            return (null, default);
        }

        var (hello, rr) = await TryGetClientHelloAsync(context, token);
        if (!hello.HasValue)
        {
            logger.NotFoundRouteSni("client hello failed");
            return (null, rr);
        }

        var h = hello.Value;
        // NOTE(review): TargetName may be absent when the client sends no SNI extension — confirm Reverse() tolerates that.
        var r = await sniRoute.MatchAsync(h.TargetName.Reverse(), h, MatchSNI);
        if (r is null)
        {
            logger.NotFoundRouteSni(h.TargetName);
        }
        return (r, rr);
    }

    /// <summary>
    /// TLS version matching: a route matches when it does not restrict protocols, or when it
    /// shares at least one protocol version with the versions offered in the ClientHello.
    /// </summary>
    private bool MatchSNI(RouteConfig config, TlsFrameInfo info)
    {
        if (!config.SupportSslProtocols.HasValue)
        {
            return true;
        }

        var v = config.SupportSslProtocols.Value;
        if (v == SslProtocols.None)
        {
            return true;
        }

        var t = info.SupportedVersions;

        // True when both the route and the client have the given protocol flag set.
        bool Both(SslProtocols p) => v.HasFlag(p) && t.HasFlag(p);

        return Both(SslProtocols.Tls13)
            || Both(SslProtocols.Tls12)
            || Both(SslProtocols.Tls11)
            || Both(SslProtocols.Tls)
            || Both(SslProtocols.Ssl3)
            || Both(SslProtocols.Ssl2)
            || Both(SslProtocols.Default);
    }

    /// <summary>
    /// Reads from the connection until a complete TLS frame can be parsed.
    /// </summary>
    /// <returns>
    /// The parsed frame info and the read result containing it, or <c>(null, result)</c>
    /// when the input completed before a frame could be parsed. On success the reader has
    /// not been advanced, so the caller still owns the buffered bytes.
    /// </returns>
    private static async ValueTask<(TlsFrameInfo?, ReadResult)> TryGetClientHelloAsync(ConnectionContext context, CancellationToken token)
    {
        var input = context.Transport.Input;
        TlsFrameInfo info = default;
        while (true)
        {
            var f = await input.ReadAsync(token).ConfigureAwait(false);
            if (f.IsCompleted)
            {
                return (null, f);
            }

            var buffer = f.Buffer;
            // The span is passed straight to the parser (no span local) so nothing ref-like lives in the async state machine.
            if (buffer.Length > 0
                && TlsFrameHelper.TryGetFrameInfo(buffer.IsSingleSegment ? buffer.First.Span : buffer.ToArray(), ref info))
            {
                return (info, f);
            }

            // Not enough data yet (or an empty read): consume nothing but mark everything as
            // examined. Every ReadAsync must be followed by AdvanceTo before the next ReadAsync,
            // including the empty-buffer case, otherwise the reader throws.
            input.AdvanceTo(buffer.Start, buffer.End);
        }
    }
搬运 tcp数据
/// <summary>
/// SNI pass-through proxy: matches a route from the TLS ClientHello, connects to an upstream,
/// replays the already-read ClientHello bytes and then copies data in both directions until
/// one direction finishes.
/// </summary>
/// <param name="context">Incoming connection.</param>
private async Task SNIProxyAsync(ConnectionContext context)
{
    // NOTE(review): the rented token source is never handed back here — confirm whether it must be returned to the pool.
    var c = cancellationTokenSourcePool.Rent();
    c.CancelAfter(options.ConnectionTimeout);
    var (route, r) = await router.MatchSNIAsync(context, c.Token);
    if (route is null)
    {
        return;
    }

    context.Route = route;
    logger.ProxyBegin(route.RouteId);
    ConnectionContext upstream = null;
    // Set only after the concurrency counter was incremented, so the finally block
    // never decrements a counter that was not incremented by this call.
    var counted = false;
    try
    {
        upstream = await DoConnectionAsync(context, route, route.RetryCount);
        if (upstream is null)
        {
            logger.NotFoundAvailableUpstream(route.ClusterId);
        }
        else
        {
            context.SelectedDestination?.ConcurrencyCounter.Increment();
            counted = true;
            var cts = route.CreateTimeoutTokenSource(cancellationTokenSourcePool);
            var t = cts.Token;
            // The only difference from the plain tcp proxy: the ClientHello has already been
            // read from the client, so it must be sent to the upstream first.
            await r.CopyToAsync(upstream.Transport.Output, t);
            context.Transport.Input.AdvanceTo(r.Buffer.End);

            Task requestCopy;
            Task responseCopy;
            if (hasMiddlewareTcp)
            {
                requestCopy = context.Transport.Input.CopyToAsync(new MiddlewarePipeWriter(upstream.Transport.Output, context, reqTcp), t);
                responseCopy = upstream.Transport.Input.CopyToAsync(new MiddlewarePipeWriter(context.Transport.Output, context, respTcp), t);
            }
            else
            {
                requestCopy = context.Transport.Input.CopyToAsync(upstream.Transport.Output, t);
                responseCopy = upstream.Transport.Input.CopyToAsync(context.Transport.Output, t);
            }

            var task = await Task.WhenAny(requestCopy, responseCopy);
            if (task.IsCanceled)
            {
                logger.ProxyTimeout(route.RouteId, route.Timeout);
            }
        }
    }
    catch (OperationCanceledException)
    {
        logger.ConnectUpstreamTimeout(route.RouteId);
    }
    catch (Exception ex)
    {
        // Report this method's own name (previously logged as TcpProxyAsync by copy-paste).
        logger.UnexpectedException(nameof(SNIProxyAsync), ex);
    }
    finally
    {
        if (counted)
        {
            context.SelectedDestination?.ConcurrencyCounter.Decrement();
        }
        upstream?.Abort();
    }
    logger.ProxyEnd(route.RouteId);
}

组件各部分都是可替换或者可增加的

因为整体都是基于 ioc 的,所以组件各部分都是可替换或者可增加的, 可定制性和扩展性还是很高的哦

目前暴露的列表可在 代码这里面查看

/// <summary>
/// Registers the default singleton services (hosting, transports, routing, load balancing
/// and health checking) on the builder's service collection.
/// </summary>
/// <param name="builder">Host builder whose services are populated.</param>
/// <returns>The same <paramref name="builder"/>, to allow chaining.</returns>
internal static HostApplicationBuilder UseOrzDefaults(this HostApplicationBuilder builder)
{
    // Registration order is kept as-is: several interfaces are registered more than once.
    builder.Services
        // hosting / infrastructure
        .AddSingleton<IHostedService, HostedService>()
        .AddSingleton(TimeProvider.System)
        .AddSingleton<IMeterFactory, DummyMeterFactory>()
        .AddSingleton<IServer, OrzServer>()
        .AddSingleton<OrzLogger>()
        .AddSingleton<OrzMetrics>()
        // transports
        .AddSingleton<IConnectionListenerFactory, SocketTransportFactory>()
        .AddSingleton<IConnectionListenerFactory, UdpTransportFactory>()
        .AddSingleton<IUdpConnectionFactory, UdpConnectionFactory>()
        .AddSingleton<IConnectionFactory, SocketConnectionFactory>()
        // configuration, routing and proxying
        .AddSingleton<IRouteContractorValidator, RouteContractorValidator>()
        .AddSingleton<IEndPointConvertor, CommonEndPointConvertor>()
        .AddSingleton<IL4Router, L4Router>()
        .AddSingleton<IOrderMiddleware, L4ProxyMiddleware>()
        .AddSingleton<ILoadBalancingPolicyFactory, LoadBalancingPolicy>()
        .AddSingleton<IClusterConfigValidator, ClusterConfigValidator>()
        .AddSingleton<IDestinationResolver, DnsDestinationResolver>()
        // load balancing policies
        .AddSingleton<ILoadBalancingPolicy, RandomLoadBalancingPolicy>()
        .AddSingleton<ILoadBalancingPolicy, RoundRobinLoadBalancingPolicy>()
        .AddSingleton<ILoadBalancingPolicy, LeastRequestsLoadBalancingPolicy>()
        .AddSingleton<ILoadBalancingPolicy, PowerOfTwoChoicesLoadBalancingPolicy>()
        // health checking
        .AddSingleton<IHealthReporter, PassiveHealthReporter>()
        .AddSingleton<IHealthUpdater, HealthyAndUnknownDestinationsUpdater>()
        .AddSingleton<IActiveHealthCheckMonitor, ActiveHealthCheckMonitor>()
        .AddSingleton<IActiveHealthChecker, ConnectionActiveHealthChecker>();

    return builder;
}

比如要添加 负载均衡策略,就可以实现

/// <summary>
/// A load balancing policy: picks one destination out of the available ones.
/// Implement this interface to add a new policy.
/// </summary>
public interface ILoadBalancingPolicy
{
    /// <summary>
    /// Name of the policy.
    /// NOTE(review): presumably the value matched against a cluster's "LoadBalancingPolicy" setting — confirm.
    /// </summary>
    string Name { get; }

    /// <summary>
    /// Picks a destination for the given connection.
    /// </summary>
    /// <param name="context">Connection being proxied.</param>
    /// <param name="availableDestinations">Destinations to choose from.</param>
    /// <returns>The chosen destination, or <c>null</c> when none is picked.</returns>
    DestinationState? PickDestination(ConnectionContext context, IReadOnlyList<DestinationState> availableDestinations);
}

如果对全部已有负载均衡策略都不满意,那就可以直接替换 ILoadBalancingPolicyFactory

/// <summary>
/// Entry point for destination selection. Replace this service to take over
/// load balancing entirely instead of adding an individual policy.
/// </summary>
public interface ILoadBalancingPolicyFactory
{
    /// <summary>
    /// Picks a destination for the given connection and route.
    /// </summary>
    /// <param name="context">Connection being proxied.</param>
    /// <param name="route">Route matched for the connection.</param>
    /// <returns>The chosen destination, or <c>null</c> when none is picked.</returns>
    DestinationState? PickDestination(ConnectionContext context, RouteConfig route);
}

比如你就可以通过sni将开发环境(或者其他环境)无法访问的请求在一台有其他访问权限的机器进行转发

差不多就做了这些,造轮子还是挺好玩的,当然大家如果在 GitHub https://github.com/fs7744/NZOrz 点个 star, 就更好玩了

发表评论

评论已关闭。

相关文章