# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""
Test cases for twisted.protocols package.
"""
from twisted.trial import unittest
from twisted.protocols import basic,wire
from twisted.internet import reactor,protocol
import string
import StringIO
class StringIOWithoutClosing(StringIO.StringIO):
    """A StringIO whose close() does nothing.

    The buffer therefore stays readable (e.g. via getvalue()) after
    something has called close() on it.
    """

    def close(self):
        """Deliberately do nothing, so the buffer contents are kept."""
class LineTester(basic.LineReceiver):
    """LineReceiver that records everything it is given.

    Every line is appended to self.received.  An empty line switches the
    protocol to raw mode; a line of the form 'len N' stores N in
    self.length, which is the number of raw bytes taken by
    rawDataReceived before line mode is resumed.
    """

    delimiter = '\n'
    MAX_LENGTH = 64

    def connectionMade(self):
        """Begin with an empty record of received data."""
        self.received = []

    def lineReceived(self, line):
        """Record line, then act on it if it is a control line."""
        self.received.append(line)
        if line == '':
            # An empty line introduces a raw-data section.
            self.setRawMode()
        elif line.startswith('len '):
            self.length = int(line[4:])

    def rawDataReceived(self, data):
        """Take up to self.length bytes from data.

        The bytes taken are appended to the most recent entry of
        self.received.  When no more bytes are outstanding, line mode is
        resumed with whatever part of data was not taken.
        """
        wanted = self.length
        chunk = data[:wanted]
        leftover = data[wanted:]
        self.length = wanted - len(chunk)
        self.received[-1] += chunk
        if self.length == 0:
            self.setLineMode(leftover)

    def lineLengthExceeded(self, line):
        """Resume line mode with the part of line past MAX_LENGTH + 1.

        Nothing is done when line is not longer than MAX_LENGTH + 1.
        """
        cutoff = self.MAX_LENGTH + 1
        if len(line) > cutoff:
            self.setLineMode(line[cutoff:])
class WireTestCase(unittest.TestCase):
    """Tests for the protocols in twisted.protocols.wire.

    Each test connects a protocol to an in-memory transport and checks
    what ended up written to it.
    """

    def _attach(self, proto):
        """Connect proto to a new in-memory transport; return its buffer."""
        sink = StringIOWithoutClosing()
        proto.makeConnection(protocol.FileWrapper(sink))
        return sink

    def _feedWords(self, proto):
        """Deliver the five sample strings to proto, one call per string."""
        for word in ("hello", "world", "how", "are", "you"):
            proto.dataReceived(word)

    def testEcho(self):
        """Echo: the transport holds everything received, in order."""
        echo = wire.Echo()
        sink = self._attach(echo)
        self._feedWords(echo)
        self.failUnlessEqual(sink.getvalue(), "helloworldhowareyou")

    def testWho(self):
        """Who: 'root\\r\\n' has been written once connected."""
        sink = self._attach(wire.Who())
        self.failUnlessEqual(sink.getvalue(), "root\r\n")

    def testQOTD(self):
        """QOTD: the quote has been written once connected."""
        sink = self._attach(wire.QOTD())
        expected = "An apple a day keeps the doctor away.\r\n"
        self.failUnlessEqual(sink.getvalue(), expected)

    def testDiscard(self):
        """Discard: nothing is written, whatever is received."""
        discard = wire.Discard()
        sink = self._attach(discard)
        self._feedWords(discard)
        self.failUnlessEqual(sink.getvalue(), "")
class LineReceiverTestCase(unittest.TestCase):
    """Feed a mixed line/raw stream to LineTester in packets of many sizes."""

    # Input stream.  LineTester enters raw mode only when it receives an
    # EMPTY line, so every raw section below must be introduced by one;
    # without the empty lines the raw entries of `output` (those containing
    # '\n', and the bare '') could never be produced.  The 70-character
    # line is longer than LineTester.MAX_LENGTH and exercises
    # lineLengthExceeded.
    buffer = '''\
len 10

0123456789len 5

1234
len 20
foo 123

0123456789
012345678len 0
foo 5

1234567890123456789012345678901234567890123456789012345678901234567890
len 1

a'''

    # What LineTester.received must hold once all of `buffer` is delivered.
    # Each empty line recorded on entering raw mode is overwritten by the
    # raw data that follows it.
    output = ['len 10', '0123456789', 'len 5', '1234\n',
              'len 20', 'foo 123', '0123456789\n012345678',
              'len 0', 'foo 5', '', '67890', 'len 1', 'a']

    def testBuffer(self):
        """Deliver `buffer` in packets of 1 to 9 bytes; compare to `output`.

        The result must be the same however the stream is split up.
        """
        for packet_size in range(1, 10):
            t = StringIOWithoutClosing()
            a = LineTester()
            a.makeConnection(protocol.FileWrapper(t))
            # Step through the buffer directly instead of computing a
            # packet count by division; every packet is non-empty.
            for start in range(0, len(self.buffer), packet_size):
                a.dataReceived(self.buffer[start:start + packet_size])
            self.failUnlessEqual(self.output, a.received)
class TestNetstring(basic.NetstringReceiver):
    """NetstringReceiver that collects every received string in a list."""

    def connectionMade(self):
        """Begin with an empty list of received strings."""
        self.received = list()

    def stringReceived(self, s):
        """Append s to self.received."""
        self.received += [s]
class TestSafeNetstring(basic.SafeNetstringReceiver):
    """SafeNetstringReceiver with MAX_LENGTH 50 that ignores its payloads."""

    MAX_LENGTH = 50
    # Set to 1 on the instance by connectionLost.
    closed = 0

    def stringReceived(self, s):
        """Ignore the received string."""

    def connectionLost(self):
        """Record on the instance that the connection was lost."""
        self.closed = 1
class NetstringReceiverTestCase(unittest.TestCase):
    """Round-trip and rejection tests for the netstring receivers."""

    # Payloads sent with sendString and expected back unchanged.
    strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515]

    # Inputs after which the safe receiver must have brokenPeer set.
    illegal_strings = [
        '9999999999999999999999',
        'abc',
        '4:abcde',
        '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',
    ]

    def testBuffer(self):
        """Encode `strings`, feed the bytes back in packets of 1 to 9.

        Raises AssertionError carrying the received list if it differs
        from `strings`.
        """
        for packet_size in range(1, 10):
            sink = StringIOWithoutClosing()
            receiver = TestNetstring()
            receiver.makeConnection(protocol.FileWrapper(sink))
            for payload in self.strings:
                receiver.sendString(payload)
            encoded = sink.getvalue()
            # Stepping by packet_size yields exactly the non-empty packets.
            for start in range(0, len(encoded), packet_size):
                receiver.dataReceived(encoded[start:start + packet_size])
            if receiver.received != self.strings:
                raise AssertionError(receiver.received)

    def getSafeNS(self):
        """Return a TestSafeNetstring connected to an in-memory transport."""
        receiver = TestSafeNetstring()
        receiver.makeConnection(protocol.FileWrapper(StringIOWithoutClosing()))
        return receiver

    def testSafe(self):
        """Each illegal input must leave a fresh receiver with brokenPeer set."""
        for bad in self.illegal_strings:
            receiver = self.getSafeNS()
            receiver.dataReceived(bad)
            if receiver.brokenPeer:
                continue
            raise AssertionError("connection wasn't closed on illegal netstring %s" % repr(bad))
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()