resolver.py :  » Web-Frameworks » Zope » Zope-2.6.0 » ZServer » medusa » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Web Frameworks » Zope 
Zope » Zope 2.6.0 » ZServer » medusa » resolver.py
# -*- Mode: Python; tab-width: 4 -*-

#
#  Author: Sam Rushing <rushing@nightmare.com>
#

RCS_ID =  '$Id: resolver.py,v 1.10 2002/03/21 15:48:53 htrd Exp $'


# Fast, low-overhead asynchronous name resolver.  Uses 'pre-cooked'
# DNS requests and unpacks only as much of the reply as it needs.

# see rfc1035 for details

import string
import asyncore
import socket
import sys
import time
from counter import counter

# VERSION is the third whitespace-separated field of an expanded RCS
# '$Id: ...$' keyword (the revision number); '0.0' when the keyword
# has not been expanded.
VERSION = '0.0'
if RCS_ID.startswith('$Id: '):
    VERSION = RCS_ID.split()[2]

# header
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      ID                       |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    QDCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    ANCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    NSCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    ARCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+


# question
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                                               |
#    /                     QNAME                     /
#    /                                               /
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     QTYPE                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     QCLASS                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+

# build a DNS address request, _quickly_
def fast_address_request (host, id=0):
    """Build a complete DNS query packet asking for the A record of <host>.

    host -- dotted host name; a single trailing '.' (fully-qualified
            form) is accepted
    id   -- request ID; only its low 16 bits are placed in the header

    Returns the packet as a string: the 2-byte ID, a fixed header
    (flags \\001\\000, QDCOUNT=1, all other counts 0), the QNAME as
    length-prefixed labels ending in a zero octet, then QTYPE=1 (A)
    and QCLASS=1 (IN).
    """
    # A trailing dot would otherwise produce an empty label, i.e. a
    # premature zero-length terminator that corrupts the question.
    if host.endswith ('.'):
        host = host[:-1]
    if host:
        labels = host.split ('.')
    else:
        # nothing left: the QNAME is just the terminating zero octet
        labels = []
    qname = ''.join ([chr (len (label)) + label for label in labels])
    return (
            chr ((id>>8)&0xff) + chr (id&0xff)
            + '\001\000\000\001\000\000\000\000\000\000'
            + qname
            + '\000\000\001\000\001'
            )
    
def fast_ptr_request (host, id=0):
    """Build a complete DNS query packet asking for the PTR record of <host>.

    host -- dotted name to query exactly as given (callers pass the
            already reversed '...in-addr.arpa' form); a single
            trailing '.' is accepted
    id   -- request ID; only its low 16 bits are placed in the header

    Returns the packet as a string: the 2-byte ID, a fixed header
    (flags \\001\\000, QDCOUNT=1, all other counts 0), the QNAME as
    length-prefixed labels ending in a zero octet, then QTYPE=12 (PTR)
    and QCLASS=1 (IN).
    """
    # A trailing dot would otherwise produce an empty label, i.e. a
    # premature zero-length terminator that corrupts the question.
    if host.endswith ('.'):
        host = host[:-1]
    if host:
        labels = host.split ('.')
    else:
        # nothing left: the QNAME is just the terminating zero octet
        labels = []
    qname = ''.join ([chr (len (label)) + label for label in labels])
    return (
            chr ((id>>8)&0xff) + chr (id&0xff)
            + '\001\000\000\001\000\000\000\000\000\000'
            + qname
            + '\000\000\014\000\001'
            )
    
def unpack_name (r,pos):
    """Decode the domain name that starts at offset <pos> of reply <r>.

    Follows rfc1035 compression pointers and returns the labels joined
    with '.'.  Raises IndexError if the name runs past the end of <r>
    and ValueError if the compression pointers form a loop.
    """
    n = []
    hops = 0
    while 1:
        ll = ord(r[pos])
        if (ll&0xc0):
            # compression: the low 6 bits of this octet and the whole
            # next octet form a 14-bit offset into the message.  The
            # mask must be applied before the shift ('<<' binds tighter
            # than '&').
            hops = hops + 1
            if hops > len(r):
                # more jumps than the message has octets: the reply
                # contains a pointer cycle and would never terminate
                raise ValueError ('compression pointer loop in DNS name')
            pos = ((ll&0x3f) << 8) + (ord(r[pos+1]))
        elif ll == 0:
            # zero-length label terminates the name
            break
        else:
            pos = pos + 1
            n.append (r[pos:pos+ll])
            pos = pos + ll
    return '.'.join (n)
    
def skip_name (r,pos):
    """Return the offset just past the encoded name starting at <pos> in <r>.

    Labels are stepped over one at a time.  A compression pointer ends
    the name and occupies two octets; a zero-length label ends it and
    occupies one.
    """
    length = ord (r[pos])
    # walk plain labels until a terminator (0) or a pointer (top bits set)
    while length and not (length & 0xc0):
        pos += length + 1
        length = ord (r[pos])
    if length:
        # compression
        return pos + 2
    return pos + 1
    
def unpack_ttl (r,pos):
    """Return the big-endian integer held in r[pos:pos+4] (the TTL field)."""
    octets = [ord (ch) for ch in r[pos:pos+4]]
    # fold the octets together, most significant first
    return reduce (lambda acc, octet: (acc << 8) | octet, octets)
    
    # resource record
    #                                    1  1  1  1  1  1
    #      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                                               |
    #    /                                               /
    #    /                      NAME                     /
    #    |                                               |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                      TYPE                     |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                     CLASS                     |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                      TTL                      |
    #    |                                               |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    #    |                   RDLENGTH                    |
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
    #    /                     RDATA                     /
    #    /                                               /
    #    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    
def unpack_address_reply (r):
    """Pull the first A record out of the raw DNS reply <r>.

    Returns (ttl, 'a.b.c.d') for the first answer whose TYPE is A and
    CLASS is IN, or (0, None) when the answer section has no such
    record.
    """
    remaining = (ord(r[6])<<8) + (ord(r[7]))      # ANCOUNT
    # the question's name starts right after the 12-byte header and is
    # followed by QTYPE and QCLASS (two octets each)
    cursor = skip_name (r, 12) + 4
    while remaining > 0:
        remaining = remaining - 1
        cursor = skip_name (r, cursor)
        # looking very specifically for TYPE=A, CLASS=IN
        if r[cursor:cursor+4] == '\000\001\000\001':
            ttl = unpack_ttl (r, cursor+4)
            # RDATA follows TYPE, CLASS, TTL and RDLENGTH (10 octets)
            octets = tuple ([ord (ch) for ch in r[cursor+10:cursor+14]])
            return ttl, '%d.%d.%d.%d' % octets
        # not what we want: step over TYPE, CLASS, TTL, RDLENGTH, RDATA
        rdlength = (ord(r[cursor+8])<<8) + (ord(r[cursor+9]))
        cursor = cursor + 10 + rdlength
    return 0, None
        
def unpack_ptr_reply (r):
    """Pull the first PTR record out of the raw DNS reply <r>.

    Returns (ttl, name) for the first answer whose TYPE is PTR and
    CLASS is IN, or (0, None) when the answer section has no such
    record.
    """
    remaining = (ord(r[6])<<8) + (ord(r[7]))      # ANCOUNT
    # the question's name starts right after the 12-byte header and is
    # followed by QTYPE and QCLASS (two octets each)
    cursor = skip_name (r, 12) + 4
    while remaining > 0:
        remaining = remaining - 1
        cursor = skip_name (r, cursor)
        # looking very specifically for TYPE=PTR, CLASS=IN
        if r[cursor:cursor+4] == '\000\014\000\001':
            ttl = unpack_ttl (r, cursor+4)
            # RDATA (the target name) follows TYPE, CLASS, TTL and
            # RDLENGTH (10 octets)
            return ttl, unpack_name (r, cursor+10)
        # not what we want: step over TYPE, CLASS, TTL, RDLENGTH, RDATA
        rdlength = (ord(r[cursor+8])<<8) + (ord(r[cursor+9]))
        cursor = cursor + 10 + rdlength
    return 0, None
        
        
        # This is a UDP (datagram) resolver.
        
        #
        # It may be useful to implement a TCP resolver.  This would presumably
        # give us more reliable behavior when things get too busy.  A TCP
        # client would have to manage the connection carefully, since the
        # server is allowed to close it at will (the RFC recommends closing
        # after 2 minutes of idle time).
        #
        # Note also that the TCP client will have to prepend each request
        # with a 2-byte length indicator (see rfc1035).
        #
        
class resolver (asyncore.dispatcher):
    """Asynchronous DNS client that talks UDP to a single name server.

    Every outstanding query is kept in self.request_map under its
    16-bit request ID as (host, unpack, callback, when).  When a reply
    carrying that ID arrives, unpack(reply) is expected to return
    (ttl, answer) and callback(host, ttl, answer) is invoked.  Queries
    more than three minutes old are reaped the next time a lookup is
    started, and their callback receives (host, 0, None).
    """

    # source of request IDs; a class attribute, so it is shared by
    # every resolver instance
    id = counter()

    def __init__ (self, server='127.0.0.1'):
        """Create the UDP socket; <server> is the name server to query."""
        asyncore.dispatcher.__init__ (self)
        self.create_socket (socket.AF_INET, socket.SOCK_DGRAM)
        self.server = server
        self.request_map = {}
        self.last_reap_time = int(time.time())      # reap every few minutes

    def writable (self):
        # queries are sent directly with sendto(), so the socket never
        # needs to be polled for writability
        return 0

    def log (self, *args):
        # discard whatever would be passed to dispatcher.log()
        pass

    def handle_close (self):
        self.log_info('closing!')
        self.close()

    def handle_error (self):      # don't close the connection on error
        (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
        self.log_info(
                        'Problem with DNS lookup (%s:%s %s)' % (t, v, tbinfo),
                        'error')

    def get_id (self):
        """Return the current request ID, reduced to 16 bits."""
        return (self.id.as_long() % (1<<16))

    def reap (self):          # find DNS requests that have timed out
        now = int(time.time())
        if now - self.last_reap_time > 180:        # reap every 3 minutes
            self.last_reap_time = now              # update before we forget
            # items() is a list in Python 2, so deleting entries while
            # looping over it is safe
            for k,(host,unpack,callback,when) in self.request_map.items():
                if now - when > 180:               # over 3 minutes old
                    del self.request_map[k]
                    try:                           # same code as in handle_read
                        callback (host, 0, None)   # timeout val is (0,None)
                    except:
                        (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
                        self.log_info('%s %s %s' % (t,v,tbinfo), 'error')

    def resolve (self, host, callback):
        """Start an A lookup; callback(host, ttl, address) gets the result."""
        self.reap()                                # first, get rid of old guys
        request_id = self.get_id()
        self.socket.sendto (
                fast_address_request (host, request_id),
                (self.server, 53)
                )
        self.request_map [request_id] = (
                host, unpack_address_reply, callback, int(time.time()))
        self.id.increment()

    def resolve_ptr (self, host, callback):
        """Start a PTR lookup for dotted-quad <host>; callback(host, ttl, name)."""
        self.reap()                                # first, get rid of old guys
        ip = host.split ('.')
        ip.reverse()
        ip = '.'.join (ip) + '.in-addr.arpa'
        request_id = self.get_id()
        self.socket.sendto (
                fast_ptr_request (ip, request_id),
                (self.server, 53)
                )
        self.request_map [request_id] = (
                host, unpack_ptr_reply, callback, int(time.time()))
        self.id.increment()

    def handle_read (self):
        reply, whence = self.socket.recvfrom (512)
        # for security reasons we may want to double-check
        # that <whence> is the server we sent the request to.
        request_id = (ord(reply[0])<<8) + ord(reply[1])
        if self.request_map.has_key (request_id):
            host, unpack, callback, when = self.request_map[request_id]
            del self.request_map[request_id]
            try:
                ttl, answer = unpack (reply)
            except Exception:
                # The request has already been removed from the map, so
                # a reply that cannot be parsed must still reach the
                # callback or the requester would wait forever.  Report
                # it with the same value used for a timeout.
                (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
                self.log_info('%s %s %s' % ( t,v,tbinfo), 'error')
                ttl, answer = 0, None
            try:
                callback (host, ttl, answer)
            except:
                (file,fun,line), t, v, tbinfo = asyncore.compact_traceback()
                self.log_info('%s %s %s' % ( t,v,tbinfo), 'error')
                
class rbl (resolver):
    """Resolver that looks addresses up in the rbl.maps.vix.com zone.

    The callback receives (host, 0, rcode), where rcode is the RCODE
    field of the reply header.
    """

    def resolve_maps (self, host, callback):
        """Query <reversed host>.rbl.maps.vix.com; callback(host, 0, rcode)."""
        self.reap()                                # first, get rid of old guys
        ip = host.split ('.')
        ip.reverse()
        ip = '.'.join (ip) + '.rbl.maps.vix.com'
        request_id = self.get_id()
        self.socket.sendto (
                fast_ptr_request (ip, request_id),
                (self.server, 53)
                )
        # The entry must have the same four fields (host, unpack,
        # callback, when) that resolver.handle_read and resolver.reap
        # unpack; a shorter tuple makes both of them raise ValueError.
        self.request_map [request_id] = (
                host, self.check_reply, callback, int(time.time()))
        self.id.increment()

    def check_reply (self, r):
        """Return (0, RCODE) for raw reply <r>; used as the unpack function."""
        # we only need to check RCODE.
        rcode = (ord(r[3])&0xf)
        self.log_info('MAPS RBL; RCODE =%02x\n %s' % (rcode, repr(r)))
        return 0, rcode # (ttl, answer)
        
        
class hooked_callback:
    """Callable that passes its arguments to <hook>, then to <callback>.

    Both are called with exactly the positional arguments given to the
    instance; their return values are discarded.
    """

    def __init__ (self, hook, callback):
        self.hook, self.callback = hook, callback

    def __call__ (self, *args):
        # extended call syntax instead of the deprecated apply() builtin
        self.hook (*args)
        self.callback (*args)
        
class caching_resolver (resolver):
    "Cache DNS queries.  Will need to honor the TTL value in the replies"

    def __init__ (self, *args, **kw):
        """Accept the same arguments as resolver.__init__.

        Keyword arguments (e.g. server=...) are forwarded too, so the
        subclass can be constructed exactly like its base class.
        """
        resolver.__init__ (self, *args, **kw)
        # host -> (time stored, ttl, answer)
        # NOTE(review): forward and reverse answers share this one dict,
        # keyed by the bare host string -- confirm that callers never
        # use the same string for both kinds of lookup.
        self.cache = {}
        self.forward_requests = counter()
        self.reverse_requests = counter()
        self.cache_hits = counter()

    def resolve (self, host, callback):
        """Forward lookup; answered from the cache when <host> is present."""
        self.forward_requests.increment()
        if self.cache.has_key (host):
            when, ttl, answer = self.cache[host]
            # ignore TTL for now
            callback (host, ttl, answer)
            self.cache_hits.increment()
        else:
            # callback_hook runs first and stores the answer, then the
            # caller's callback receives it
            resolver.resolve (
                    self,
                    host,
                    hooked_callback (
                            self.callback_hook,
                            callback
                            )
                    )

    def resolve_ptr (self, host, callback):
        """Reverse lookup; answered from the cache when <host> is present."""
        self.reverse_requests.increment()
        if self.cache.has_key (host):
            when, ttl, answer = self.cache[host]
            # ignore TTL for now
            callback (host, ttl, answer)
            self.cache_hits.increment()
        else:
            resolver.resolve_ptr (
                    self,
                    host,
                    hooked_callback (
                            self.callback_hook,
                            callback
                            )
                    )

    def callback_hook (self, host, ttl, answer):
        """Record a finished lookup in the cache, stamped with the current time."""
        self.cache[host] = time.time(), ttl, answer

    SERVER_IDENT = 'Caching DNS Resolver (V%s)' % VERSION

    def status (self):
        """Return a producers.simple_producer holding an HTML status summary."""
        import status_handler
        import producers
        return producers.simple_producer (
                '<h2>%s</h2>'          % self.SERVER_IDENT
                + '<br>Server: %s'        % self.server
                + '<br>Cache Entries: %d'    % len(self.cache)
                + '<br>Outstanding Requests: %d' % len(self.request_map)
                + '<br>Forward Requests: %s'  % self.forward_requests
                + '<br>Reverse Requests: %s'  % self.reverse_requests
                + '<br>Cache Hits: %s'      % self.cache_hits
                )
        
        #test_reply = """\000\000\205\200\000\001\000\001\000\002\000\002\006squirl\011nightmare\003com\000\000\001\000\001\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\011nightmare\003com\000\000\002\000\001\000\001Q\200\000\002\300\014\3006\000\002\000\001\000\001Q\200\000\015\003ns1\003iag\003net\000\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\300]\000\001\000\001\000\000\350\227\000\004\314\033\322\005"""
        # def test_unpacker ():
        #   print unpack_address_reply (test_reply)
        # 
        # import time
        # class timer:
        #   def __init__ (self):
        #     self.start = time.time()
        #   def end (self):
        #     return time.time() - self.start
        # 
        # # I get ~290 unpacks per second for the typical case, compared to ~48
        # # using dnslib directly.  also, that latter number does not include
        # # picking the actual data out.
        # 
        # def benchmark_unpacker():
        # 
        #   r = range(1000)
        #   t = timer()
        #   for i in r:
        #     unpack_address_reply (test_reply)
        #   print '%.2f unpacks per second' % (1000.0 / t.end())
        
if __name__ == '__main__':
    # Command-line driver: resolve every host named on the command line
    # and print each answer as it arrives.
    #   -r           do reverse (PTR) lookups
    #   -m           query the MAPS RBL instead
    #   -s <server>  name server to use (default 127.0.0.1)
    import sys

    usage = 'usage: %s [-r] [-m] [-s <server_IP>] host [host ...]' % sys.argv[0]

    if len(sys.argv) == 1:
        print (usage)
        sys.exit(0)

    if ('-s' in sys.argv):
        i = sys.argv.index('-s')
        if i + 1 >= len(sys.argv):
            # '-s' was the last word: there is no server address to read
            print (usage)
            sys.exit(1)
        server = sys.argv[i+1]
        del sys.argv[i:i+2]
    else:
        server = '127.0.0.1'

    if ('-r' in sys.argv):
        reverse = 1
        i = sys.argv.index('-r')
        del sys.argv[i]
    else:
        reverse = 0

    if ('-m' in sys.argv):
        maps = 1
        sys.argv.remove ('-m')
    else:
        maps = 0

    # whatever is left after the program name is the list of hosts
    count = len(sys.argv) - 1
    if not count:
        # Only options were given.  With nothing to look up no callback
        # would ever close the resolver and the poll loop below would
        # never end.
        print (usage)
        sys.exit(1)

    if maps:
        r = rbl (server)
    else:
        r = caching_resolver(server)

    def print_it (host, ttl, answer):
        # print one result; close the resolver after the last one so
        # that the poll loop below terminates
        global count
        print ('%s: %s' % (host, answer))
        count = count - 1
        if not count:
            r.close()

    for host in sys.argv[1:]:
        if reverse:
            r.resolve_ptr (host, print_it)
        elif maps:
            r.resolve_maps (host, print_it)
        else:
            r.resolve (host, print_it)

    # hooked asyncore.loop()
    while asyncore.socket_map:
        asyncore.poll (30.0)
        print ('requests outstanding: %d' % len(r.request_map))
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.