[Python] 多线程 NS 查询工具

作者 huhamhire,暂无评论,2014年1月27日 16:44 程序实践

前段时间正好考虑了一下后面 hosts 工具的远期目标,感觉以后如果要朝着自动化维护的方向推进的话,首先需要的就是可以自动获取 hosts 的条目,并且可以对这些条目进行一些基础的访问测试。那么这些 hosts 的数据要怎么获取呢?hosts 文件中间的主要内容就是服务器的 IP 地址与其对应的域名。经过这些年的积累,手上目前积累的域名差不多有数千条,域名的问题暂时不用考虑。IP 地址的自动获取,考虑到实时有效性,自然需要通过查询 DNS 来实现。

出于上面提到的需求,我就尝试用 Python 实现了一个类似于 nslookup 的工具。有个问题是,既然很多系统都提供了 nslookup 甚至是 dig 这样的工具,为嘛我要自己来实现呢?最主要的原因还是效率的问题,毕竟我需要在最短的时间内查询数千个域名的对应 IP,并且需要查询不止一个服务器。出于速度上的考虑,就有必要使用多线程来实现。

目前实现了一个简单的工具,一共写了三个类。首先就是 NSTools 对象,这个类功能比较简单,主要包含了对于 NS 查询数据包编码及解码的静态方法。


 1 class NSTools(object):
 2 
 3     @staticmethod
 4     def encode(host_name):
 5         index = os.urandom(2)
 6         host_str = ''.join(chr(len(x)) + x for x in host_name.split('.'))
 7         data = "%s\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00%s" \
 8                "\x00\x00\x01\x00\x01" % (index, host_str)
 9         data = struct.pack("!H", len(data)) + data
10         return data
11 
12     @staticmethod
13     def decode(in_sock):
14         in_file = in_sock.makefile("rb")
15         size = struct.unpack("!H", in_file.read(2))[0]
16         data = in_file.read(size)
17         ip_list = re.findall("\xC0.\x00\x01\x00\x01.{6}(.{4})", data)
18         return [".".join(str(ord(x)) for x in s) for s in ip_list]

目前这个类仅实现了最基础的功能,只支持 A 记录(IPv4)的查询后续可能还会增加 AAAA 记录查询的功能。另外,需要说明的是,大部分网站的域名可以在 DNS 处获得不止一条的返回值,这里需要将所有返回的有效信息都记录下来。

接下来,就是比较关键的用于实现多线程查询的类。由于是多线程实现,这里给每个实例输入一个 DNS 服务器列表,以及需要查询的一条域名信息。程序自动使用相应端口向被查询的服务器发送查询数据包,并从服务器的返回信息中获取域名对应的 IP 地址信息。


 1 class NSLookup(threading.Thread):
 2     ERROR_DESC = {
 3         10054: 'ERROR: Connection reset by peer',
 4     }
 5     STATUS_DESC = {
 6         0: "OK",
 7         1: "No Hit",
 8         2: "Timed Out",
 9         3: "Conn Error",
10         4: "Decode Error"
11     }
12 
13     def __init__(self, servers, host_name, results, semaphore,
14                  ipv6=False, timeout=2, sock_type="TCP", port=53):
15         threading.Thread.__init__(self)
16         self.servers = servers
17         self.port = port
18         self.host_name = host_name
19         self.results = results
20         self.sem = semaphore
21         self.timeout = timeout
22         # Set IP family
23         if ipv6:
24             self.ip_family = socket.AF_INET6
25         else:
26             self.ip_family = socket.AF_INET
27             # Set socket type
28         if "TCP" == sock_type.upper():
29             self.sock_type = socket.SOCK_STREAM
30         else:
31             self.sock_type = socket.SOCK_DGRAM
32         self.results[host_name] = {}
33 
34     @property
35     def __sock(self):
36         try:
37             sock = socket.socket(self.ip_family, self.sock_type)
38             sock.settimeout(self.timeout)
39             return sock
40         except socket.error, (error_no, msg):
41             sys.stdout.write("\r  host: %s, Error %d: %s\n" %
42                              (self.host_name, error_no, msg))
43             raise
44 
45     def lookup(self, server_ip):
46         sock = self.__sock
47         try:
48             sock.connect((server_ip, self.port))
49             sock.sendall(NSTools.encode(self.host_name))
50             hosts = NSTools.decode(sock)
51             self._response["hosts"] = hosts
52             if hosts:
53                 # Set status OK
54                 self._response["stat"] = 0
55             else:
56                 # Set status No Results
57                 self._response["stat"] = 1
58         except socket.timeout:
59             # Set status Timeout
60             self._response["stat"] = 2
61         except socket.error:
62             # Set status Connection Error
63             self._response["stat"] = 3
64         except struct.error:
65             # Set status Decode Error
66             self._response["stat"] = 4
67         finally:
68             sock.close()
69 
70     def show_state(self, server_tag):
71         stat = self.STATUS_DESC[self._response["stat"]]
72         msg = "NSLK: " + self.host_name + " - " + server_tag
73         if stat == "OK":
74             Progress.show_status(msg, stat)
75         else:
76             Progress.show_status(msg, stat, 1)
77         Progress.progress_bar()
78 
79     def run(self):
80         responses = {}
81         for tag, ip in self.servers.iteritems():
82             self._response = {"hosts": [], "stat": 1}
83             self.lookup(ip)
84             responses[tag] = self._response
85             self.show_state(tag)
86         self.results[self.host_name] = responses
87         self.sem.release()

从代码中可以看到,这个类目前也在部分位置提供了 IPv6 的选项,不过具体查询功能还没有实现,这部分就留给后面解决了。另外,目前的查询是通过 TCP 实现的,以后有条件应该会考虑尽量使用 UDP 来解决。

如果在查询的过程中,相关的条目出现异常的情况,比如超时或者没有返回值,程序也会自动对异常进行记录。

下面最后一个类,主要实现多线程调度的管理,也是整个 NS 查询的入口。需要注意一个问题,因为 DNS 服务器也是服务器,如果在大量并发查询域名记录的时候不对并发数量加以限制,服务器必然会将你的查询视为类似 DDoS 的攻击,就有可能会暂时屏蔽你的请求,或者直接屏蔽你的 IP。无论哪种情况,一旦发生,都会导致相关查询无法继续,所以需要使用类似线程池的机制来控制并发数量。

我看到好多文章都是用 queue 来实现的线程池,其实像这里类似的只涉及并发资源数量的多线程问题,完全可以使用信号量(Semaphore)来管理。Python 的 threading 模块也直接提供了相关的方法。使用了信号量,整个并发过程就可以理解成一个典型的生产者消费者问题。只需使用一个信号量参数就可以控制活动的线程数量,新线程的创建将一直被阻塞,只到有活动现成线程的操作已经完成。


 1 class MultiNSLookup(object):
 2     # Limit the number of concurrent sessions
 3     sem = threading.Semaphore(0x20)
 4 
 5     def __init__(self, ns_servers, host_names):
 6         self.ns_servers = ns_servers
 7         self.host_names = host_names
 8         self._responses = {}
 9 
10     def nslookup(self):
11         Progress.set_total(len(self.host_names))
12         Progress.set_counter(self._responses)
13         threads = []
14         for domain in self.host_names:
15             self.sem.acquire()
16             lookup_host = NSLookup(
17                 self.ns_servers, domain, self._responses, self.sem)
18             lookup_host.start()
19             threads.append(lookup_host)
20 
21         for lookup_host in threads:
22             lookup_host.join()
23 
24         Progress.progress_bar()
25         return self._responses

这里我总结了几个 DNS 服务器查询的结果进行了并发数量配置,很可怜的只能限制在 32 (0x20) 个。如果太多的话,部分服务器可能会间歇性拒绝服务。如果是查询的服务器没有这么高的限制,适当放宽并发限制也是可以的。不过即使是同时执行 32 个线程,其效率也要比单线程执行高至少一个数量级。

最后再给出我目前测试用的主程序部分。这里的很多地方都用到了数据库来存储数据,其实主要也就是域名条目的获取,以及相关结果的储存,只要理解即可。另外就是可以在主程序中手动配置各地的 DNS 列表,比如我这里为了跟踪全球各个节点的情况,查询的都是各地的服务器。一般建议同一区域内的服务器没有必要大量重复查询,因为这样本身会产生大量的重复结果,虽然不会重复存入数据库,但是会比较严重的应该操作效率。


 1 if __name__ == '__main__':
 2     SourceData.connect_db()
 3     SourceData.drop_tables()
 4     SourceData.clear()
 5     SourceData.create_tables()
 6 
 7     ns_servers = {
 8         "us": "64.118.80.141",
 9         "uk": "62.140.195.84",
10         "de": "62.128.1.42",
11         "fr": "82.216.111.121",
12 
13         "cn": "211.157.15.189",
14         "hk": "203.80.96.10",
15         "tw": "168.95.192.1",
16         "jp": "158.205.225.226",
17         "sg": "165.21.83.88",
18         "kr": "115.68.45.3",
19         "in": "58.68.121.230",
20     }
21     cfg_file = "mods.xml"
22     set_domain = SetDomain(cfg_file)
23     set_domain.get_config()
24     set_domain.get_domains_in_mods()
25     domains = SourceData.get_domain_list()
26 
27     lookups = MultiNSLookup(ns_servers, domains)
28     responses = lookups.nslookup()
29 
30     SourceData.set_multi_ns_response(responses)

最后是查询过程中的截图。额外说明一下,用于在终端显示查询进度的对象就不在这里额外介绍了,实现起来也比较简单。

ns
关键词:DNS , Python , 工具 DIY , 网络测试
登录后进行评论