using System;
using System.Collections;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Net.NetworkInformation;
using System.Threading;

/* Base class for the sockets the server receives requests on: one UDP socket per local unicast
   address, the TCP listening sockets, and accepted TCP connections. */
abstract class Listener
{
  // Socket requests are read from and responses written to.
  public Socket Socket;

  public virtual void Close()
  {
    if ( Socket != null ) Socket.Close();
  }

  // Sends the response Packet for the request described by H.
  public abstract void SendResponse( byte [] Packet, Header H );

  // Accepted TCP connections; CheckTimeout closes those whose TimeoutEnd has passed.
  protected static readonly Set<TcpListener> TcpListeners = new Set<TcpListener>();

  // Sockets opened by Start and closed by Stop.
  private static List<Listener> Listeners = new List<Listener>();

  public static void Start()
  {
    // For UDP, RFC2181 section 4 requires that the response is sent using same IP
    // address on which request was received, so a UDP socket is bound to every local address.
    foreach ( NetworkInterface Nic in NetworkInterface.GetAllNetworkInterfaces() )
    {
      if ( IsReceiveOnly( Nic ) ) continue;
      foreach ( UnicastIPAddressInformation UIA in Nic.GetIPProperties().UnicastAddresses )
      {
        Listeners.Add( new UdpListener( 53, UIA.Address ) );
        // Listeners.Add( new QrpListener( Config.QRP_Port, UIA.Address ) );
      }
    }
    // TCP : one listening socket on the IPv4 wildcard address, one on the IPv6 wildcard address.
    Listeners.Add( new TcpListener( 53, false ) );
    Listeners.Add( new TcpListener( 53, true ) );
  }

  // NetworkInterface.IsReceiveOnly can throw ( e.g. PlatformNotSupportedException ); an interface
  // whose status cannot be read is treated as able to send, so it still gets listeners.
  private static bool IsReceiveOnly( NetworkInterface Nic )
  {
    try { return Nic.IsReceiveOnly; }
    catch ( Exception ) { return false; }
  }

  // Closes every accepted TCP connection whose TimeoutEnd is earlier than Now.
  public static void CheckTimeout( long Now )
  {
    if ( TcpListeners.Count == 0 ) return;
    // TcpListener.Close removes the connection from TcpListeners, so iterate over a copy
    // rather than modifying the set while it is being enumerated.
    foreach ( TcpListener tc in OpenTcpConnections() )
      if ( Now > tc.TimeoutEnd ) tc.Close();
  }

  public static void Stop()
  {
    foreach ( Listener x in Listeners ) x.Close();
    foreach ( TcpListener x in OpenTcpConnections() ) x.Close();
  }

  // Copy of TcpListeners, taken while holding a lock on the set, so connections can be closed
  // ( which removes them from the set ) without disturbing an enumeration.
  private static List<TcpListener> OpenTcpConnections()
  {
    List<TcpListener> Result = new List<TcpListener>();
    lock ( TcpListeners )
      foreach ( TcpListener tc in TcpListeners ) Result.Add( tc );
    return Result;
  }
}

/* The Header class is used to record transport information about an incoming DNS request that is needed to send the response.
*/
abstract class Header
{
  // Listener the request arrived on; SendResponse sends the response back through it.
  public Listener L;

  // Remote end point the request came from.
  public abstract EndPoint EP { get; }

  public abstract void Read( DnsRx Rx );

  public abstract int Write( DnsTx Tx ); // Result is number of bytes to skip after writing header

  public abstract TP TP { get; }

  public virtual bool HasPayload { get { return true; } }

  public virtual void SendResponse( byte [] Packet ) { L.SendResponse( Packet, this ); }

  public virtual void SendBytes( byte [] Response ) {}

  public virtual byte [] GetPublicKey() { return null; }

#if (Trace)
  public override String ToString() { return TP + " " + EP.ToString(); }
#endif
}

// Header that records the 16-bit message ID read from the request and writes it back in the response.
abstract class Header16 : Header
{
  protected UInt16 ID;

  public override void Read( DnsRx Rx ) { ID = Rx.Read16(); }

  public override int Write( DnsTx Tx ) { Tx.Put16( ID ); return 0; }
}

////////////////////////////////////////////////////////////// TCP

sealed class TcpHeader : Header16
{
  public override TP TP { get { return TP.TCP; } }

  public override EndPoint EP { get { return L.Socket.RemoteEndPoint; } }

  // TCP responses start with a 2-byte length; it is reserved here and filled in by
  // TcpListener.SendNext once the final packet length is known.
  public override int Write( DnsTx Tx )
  {
    Tx.Put16( 0 ); // Reserve space for length
    Tx.Put16( ID );
    return 0;
  }
}

// Either the listening socket for a port ( Port/IPv6 constructor, TimeoutEnd == 0 ) or one accepted
// connection ( Socket constructor ). Accepted connections are registered in TcpListeners and closed
// by Listener.CheckTimeout once TimeoutEnd has passed.
sealed class TcpListener : Listener
{
  // Length of the message being read; 0 while the 2-byte length prefix is being read.
  private int PacketLength;
  private DnsRx Rx;

  // Value of Cache.Ticks after which the connection may be closed; 0 for a listening socket.
  public long TimeoutEnd;

  // Responses waiting to be sent; at most one BeginSend is outstanding ( Sending ).
  // Both fields are only touched under lock ( SendQ ) because responses and send completions
  // can arrive on different threads.
  private readonly Queue<byte[]> SendQ = new Queue<byte[]>();
  private bool Sending;

  public override void SendResponse( byte [] Packet, Header H )
  {
    lock ( SendQ ) SendQ.Enqueue( Packet ); // Would it be better to enqueue E ?
    SendNext();
  }

  // Starts sending the next queued response unless a send is already in progress.
  private void SendNext()
  {
    byte [] Packet;
    lock ( SendQ )
    {
      if ( Sending || SendQ.Count == 0 ) return;
      Sending = true;
      Packet = SendQ.Dequeue();
    }
    TimeoutEnd = Cache.Ticks + 10000L * Config.TcpServerTimeout;

    // Fill in the big-endian length prefix reserved by TcpHeader.Write.
    int Length = Packet.Length - 2;
    Packet[0] = (byte) ( Length / 256 );
    Packet[1] = (byte) ( Length % 256 );

    Socket S = Socket;
    if ( S == null ) return; // Connection already closed : the response is dropped.
    try
    {
      S.BeginSend( Packet, 0, Packet.Length, SocketFlags.None, new AsyncCallback(SendCallback), this );
    }
    catch ( Exception e )
    {
      Cache.LogError( e );
      Close();
    }
  }

  private static void SendCallback( IAsyncResult ar )
  {
    TcpListener rc = (TcpListener)ar.AsyncState;
    try
    {
      Socket S = rc.Socket;
      if ( S == null ) return; // Closed while the send was in progress.
      S.EndSend( ar ); // ToDo : Check number of bytes sent is as expected?
    }
    catch ( Exception e )
    {
      // A failed send leaves the connection unusable; close it rather than leave Sending stuck.
      Cache.LogError( e );
      rc.Close();
      return;
    }
    lock ( rc.SendQ ) rc.Sending = false;
    rc.SendNext();
  }

  private void Listen()
  {
    Socket.Listen(1000); // Queue size : not sure what is appropriate (ToDo)
    Socket.BeginAccept( new AsyncCallback(AcceptCallback), this );
  }

  // Accepts one connection, then posts the next accept unless the listening socket was closed.
  public static void AcceptCallback( IAsyncResult ar )
  {
    if ( Cache.Stopping ) return;
    TcpListener rc = (TcpListener)ar.AsyncState;
    Socket Server = rc.Socket;
    if ( Server == null ) return; // Listening socket closed.
    try
    {
      Socket s = Server.EndAccept( ar );
      new TcpListener( s );
    }
    catch ( Exception e ) { Cache.LogError( e ); }

    if ( rc.Socket == null ) return; // Not shutting down
    try
    {
      Server.BeginAccept( new AsyncCallback(AcceptCallback), rc );
    }
    catch ( Exception e ) { Cache.LogError( e ); }
  }

  public TcpListener( Socket s )
  {
    Socket = s;
    Rx = new DnsRx(2);
    TimeoutEnd = Cache.Ticks + 10000L * Config.TcpServerTimeout;
    lock ( TcpListeners )
    {
      // NOTE(review): the indexer access is what registers this connection in the set
      // ( Close removes it again ); the returned value is not used.
      TcpListener Registered = TcpListeners[ this ];
    }
    Receive();
  }

  // Posts a read for the rest of the length prefix, or for the rest of the message.
  private void Receive()
  {
    Socket S = Socket;
    if ( S == null ) return; // Connection closed.
    // Only ask for the missing prefix bytes, so a short read can never overrun the 2-byte buffer.
    int Length = PacketLength == 0 ? 2 - Rx.N : Rx.Buffer.Length - Rx.N;
    S.BeginReceive( Rx.Buffer, Rx.N, Length, SocketFlags.None, new AsyncCallback(ReceiveCallback), this );
  }

  private static void ReceiveCallback( IAsyncResult ar )
  {
    TcpListener rc = (TcpListener)ar.AsyncState;
    rc.Respond( ar );
  }

  private void Respond( IAsyncResult ar )
  {
    Socket S = Socket;
    if ( S == null ) return;
    try
    {
      int Received = S.EndReceive( ar );
      if ( Received == 0 )
      {
        // Peer has stopped sending. Stop reading ( posting another receive would complete
        // immediately and spin ); queued responses can still go out and CheckTimeout closes
        // the connection.
        return;
      }
      Rx.N += Received;

      if ( PacketLength == 0 )
      {
        if ( Rx.N < 2 ) { Receive(); return; } // Only part of the length prefix so far.
        Rx.Ix = 0;
        PacketLength = Rx.Read16();
        if ( PacketLength == 0 ) { Close(); return; } // Empty message : not a valid request.
        // PacketLength comes from 16 bits, so this buffer is at most 65535 bytes.
        Rx.Buffer = new byte[ PacketLength ];
        Rx.N = 0;
        Receive();
      }
      else if ( Rx.N == PacketLength )
      {
        // Whole message received : pass it to Cache.Respond, then start reading the next one.
        TcpHeader H = new TcpHeader();
        H.L = this;
        Rx.H = H;
        Cache.Respond( Rx );
        PacketLength = 0;
        Rx = new DnsRx(2);
        Receive();
      }
      else
      {
        Receive();
      }
    }
    catch ( Exception e )
    {
      Cache.LogError( e );
      Close();
    }
  }

  public override void Close()
  {
    // Exchange so that concurrent callers ( timeout, receive error, send error ) close only once.
    Socket S = Interlocked.Exchange( ref Socket, null );
    if ( S == null ) return;
    S.Close();
    if ( TimeoutEnd != 0 && !Cache.Stopping )
      lock ( TcpListeners ) TcpListeners.Remove( this );
  }

  public TcpListener( int Port, bool IPv6 )
  {
    bool Ok = false;
    try
    {
      Socket = new Socket( IPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork,
        SocketType.Stream, ProtocolType.Tcp );
      EndPoint Local = new IPEndPoint( IPv6 ? IPAddress.IPv6Any : IPAddress.Any, Port );
      Socket.Bind( Local );
      Listen();
      if ( IPv6 ) Cache.IPv6Available = true;
      Ok = true;
    }
    catch ( Exception )
    {
      Close();
    }
#if (Trace)
    Cache.Log( "TCP Listener IPv6=" + IPv6 + " Port=" + Port + " Ok=" + Ok );
#endif
  }
}

////////////////////////////////////////////////////////////// UDP

sealed class UdpHeader : Header16
{
  // Remote end point, filled in by Socket.ReceiveFrom.
  public EndPoint REP;

  public override TP TP { get { return TP.UDP; } }

  public override EndPoint EP { get { return REP; } }

  // Replies from the socket the request arrived on, to the address it came from.
  public override void SendBytes( byte [] Response ) { L.Socket.SendTo( Response, REP ); }

  public UdpHeader( bool IPv6 )
  {
    REP = new IPEndPoint( IPv6 ? IPAddress.IPv6Any : IPAddress.Any, 0 );
  }
}

class UdpListener : Listener
{
  protected bool IPv6;

  protected virtual UdpHeader GetHeader() { return new UdpHeader( IPv6 ); }

  public override void SendResponse( byte [] Packet, Header H ) { H.SendBytes( Packet ); }

  // Note : use a thread because BeginReceiveFrom has problems picking up end point.
  // Each datagram is handed to Cache.Respond on the thread pool.
  protected void Listen()
  {
    while ( !Cache.Stopping )
    {
      try
      {
        UdpHeader H = GetHeader();
        H.L = this;
        DnsRx Rx = new DnsRx( Config.EdnsQueryLimit );
        Rx.H = H;
        Rx.N = Socket.ReceiveFrom( Rx.Buffer, ref H.REP );
        ThreadPool.QueueUserWorkItem( new WaitCallback(Cache.Respond), Rx );
      }
      catch ( ObjectDisposedException )
      {
        return; // Socket closed : every further receive would fail, so stop instead of spinning.
      }
      catch ( Exception e )
      {
        Cache.LogError( e );
      }
    }
  }

  public UdpListener( int Port, IPAddress Address )
  {
    this.IPv6 = ( Address.AddressFamily == AddressFamily.InterNetworkV6 );
    bool Ok = false;
    try
    {
      Socket = new Socket( IPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork,
        SocketType.Dgram, ProtocolType.Udp );
      EndPoint Local = new IPEndPoint( Address, Port );
      Socket.Bind( Local );
      Thread LT = new Thread( new ThreadStart(Listen) );
      LT.Priority = ThreadPriority.Highest;
      LT.Start();
      Ok = true;
    }
    catch ( Exception )
    {
      Close();
    }
#if (Trace)
    Cache.Log( "UDP Listener Address=" + Address + " Port=" + Port + " Ok=" + Ok );
#endif
  }
}