# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""Basic protocols, such as line-oriented, netstring, and 32-bit-int prefixed strings.
API Stability: semi-stable.
Maintainer: U{Itamar Shtull-Trauring<mailto:twisted@itamarst.org>}
"""
# System imports
import string
import re
import struct
# Twisted imports
from twisted.internet import protocol
from twisted.python import log
# Reader states for NetstringReceiver's incremental parser.
LENGTH, DATA, COMMA = range(3)
# Matches the decimal length prefix of a netstring and its optional ':'
# terminator.  Raw string so that '\d' is a regex class, not a string escape.
NUMBER = re.compile(r'(\d*)(:?)')
# When true, NetstringParseError is raised with the repr of the offending data.
DEBUG = 0
class NetstringParseError(ValueError):
    """Raised when incoming bytes do not follow the Netstring format."""
class NetstringReceiver(protocol.Protocol):
    """This uses djb's Netstrings protocol to break up the input into strings.

    Each string makes a callback to stringReceived, with a single
    argument of that string.

    Security features:
        1. Messages are limited in size, useful if you don't want someone
           sending you a 500MB netstring (change MAX_LENGTH to the maximum
           length you wish to accept).
        2. The connection is lost if an illegal message is received.
    """

    MAX_LENGTH = 99999
    # Set to 1 once a parse error has caused the connection to be dropped.
    brokenPeer = 0
    # Current parser state: one of LENGTH, DATA or COMMA.
    _readerState = LENGTH
    # While in LENGTH: the length parsed so far.  While in DATA: the number
    # of payload bytes still expected.
    _readerLength = 0

    def stringReceived(self, line):
        """
        Override this.
        """
        raise NotImplementedError

    def doData(self):
        """Consume payload bytes of the netstring currently being read.

        Takes at most the number of bytes still outstanding from the pending
        data.  Once the full payload has been collected, it is passed to
        stringReceived and the parser moves on to expect the trailing comma.
        """
        wanted = int(self._readerLength)
        buffer, self.__data = self.__data[:wanted], self.__data[wanted:]
        self._readerLength = self._readerLength - len(buffer)
        self.__buffer = self.__buffer + buffer
        if self._readerLength != 0:
            return
        self.stringReceived(self.__buffer)
        self._readerState = COMMA

    def doComma(self):
        """Consume the ',' that terminates a netstring.

        @raise NetstringParseError: if the next byte is not ','.
        """
        self._readerState = LENGTH
        if self.__data[0] != ',':
            if DEBUG:
                raise NetstringParseError(repr(self.__data))
            else:
                raise NetstringParseError
        self.__data = self.__data[1:]

    def doLength(self):
        """Parse (part of) the decimal length prefix and its ':' terminator.

        Digits may be split across several calls; they are accumulated in
        _readerLength.  When the ':' is seen, the parser switches to DATA.

        @raise NetstringParseError: if the pending data starts with neither
            a digit nor ':', or if the announced length exceeds MAX_LENGTH.
        """
        m = NUMBER.match(self.__data)
        if not m.end():
            if DEBUG:
                raise NetstringParseError(repr(self.__data))
            else:
                raise NetstringParseError
        self.__data = self.__data[m.end():]
        if m.group(1):
            try:
                # Shift the previously seen digits left and append the new
                # ones (the prefix may arrive in several chunks).
                self._readerLength = (
                    self._readerLength * (10 ** len(m.group(1)))
                    + int(m.group(1)))
            except (OverflowError, ValueError):
                # int() may refuse an absurdly long digit string (e.g. the
                # int/str conversion digit limit on newer Pythons); treat
                # that like any other oversized length.
                raise NetstringParseError("netstring too long")
            if self._readerLength > self.MAX_LENGTH:
                raise NetstringParseError("netstring too long")
        if m.group(2):
            self.__buffer = ''
            self._readerState = DATA

    def dataReceived(self, data):
        """Feed received bytes through the netstring state machine.

        On any NetstringParseError the connection is closed and brokenPeer
        is set.
        """
        self.__data = data
        try:
            while self.__data:
                if self._readerState == DATA:
                    self.doData()
                elif self._readerState == COMMA:
                    self.doComma()
                elif self._readerState == LENGTH:
                    self.doLength()
                else:
                    raise RuntimeError("mode is not DATA, COMMA or LENGTH")
        except NetstringParseError:
            self.transport.loseConnection()
            self.brokenPeer = 1

    def sendString(self, data):
        """Send data to the other end, framed as a netstring."""
        self.transport.write('%d:%s,' % (len(data), data))
class SafeNetstringReceiver(NetstringReceiver):
    """Deprecated: use NetstringReceiver instead.

    This subclass adds nothing to NetstringReceiver.
    """
    pass
class LineReceiver(protocol.Protocol):
    """A protocol that receives lines and/or raw data, depending on mode.

    In line mode, each line that's received becomes a callback to
    L{lineReceived}.  In raw data mode, each chunk of raw data becomes a
    callback to L{rawDataReceived}.  The L{setLineMode} and L{setRawMode}
    methods switch between the two modes.

    This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.

    @cvar delimiter: The line-ending delimiter to use. By default this is
                     '\\r\\n'.
    @cvar MAX_LENGTH: The maximum length of a line to allow.  If a received
                      line is longer than this, L{lineLengthExceeded} is
                      called, which by default drops the connection.
                      Default is 16384.
    """
    line_mode = 1
    __buffer = ''
    delimiter = '\r\n'
    MAX_LENGTH = 16384

    def dataReceived(self, data):
        """Protocol.dataReceived.

        Translates bytes into lines, and calls lineReceived (or
        rawDataReceived, depending on mode.)
        """
        self.__buffer = self.__buffer + data
        while self.line_mode:
            try:
                line, self.__buffer = self.__buffer.split(self.delimiter, 1)
            except ValueError:
                # No complete line buffered yet.  If the partial line is
                # already too long it can never become acceptable.
                if len(self.__buffer) > self.MAX_LENGTH:
                    line, self.__buffer = self.__buffer, ''
                    self.lineLengthExceeded(line)
                    return
                break
            else:
                if len(line) > self.MAX_LENGTH:
                    # Pass on the offending line itself (plus the delimiter
                    # and whatever followed it), not just the trailing data.
                    exceeded = line + self.delimiter + self.__buffer
                    self.__buffer = ''
                    self.lineLengthExceeded(exceeded)
                    return
                self.lineReceived(line)
                if self.transport.disconnecting:
                    return
        else:
            # Reached when not (or no longer) in line mode: hand whatever is
            # left in the buffer to rawDataReceived.
            data, self.__buffer = self.__buffer, ''
            if data:
                return self.rawDataReceived(data)

    def setLineMode(self, extra=''):
        """Sets the line-mode of this receiver.

        If you are calling this from a rawDataReceived callback,
        you can pass in extra unhandled data, and that data will
        be parsed for lines.  Further data received will be sent
        to lineReceived rather than rawDataReceived.
        """
        self.line_mode = 1
        return self.dataReceived(extra)

    def setRawMode(self):
        """Sets the raw mode of this receiver.

        Further data received will be sent to rawDataReceived rather
        than lineReceived.
        """
        self.line_mode = 0

    def rawDataReceived(self, data):
        """Override this for when raw data is received.
        """
        raise NotImplementedError

    def lineReceived(self, line):
        """Override this for when each line is received.
        """
        raise NotImplementedError

    def sendLine(self, line):
        """Sends a line to the other end of the connection.
        """
        self.transport.write(line + self.delimiter)

    def lineLengthExceeded(self, line):
        """Called when the maximum line length has been reached.

        Override if it needs to be dealt with in some special way.

        @param line: the discarded data.  Either the whole buffer (when no
            delimiter was found), or the oversized line followed by the
            delimiter and any data that came after it.
        """
        self.transport.loseConnection()
class Int32StringReceiver(protocol.Protocol):
    """A receiver for int32-prefixed strings.

    An int32 string is a string prefixed by 4 bytes, the 32-bit length of
    the string encoded in network byte order.

    This class publishes the same interface as NetstringReceiver.
    """
    MAX_LENGTH = 99999
    # Bytes received but not yet delivered as complete strings.
    recvd = ""

    def stringReceived(self, msg):
        """Override this.
        """
        raise NotImplementedError

    def dataReceived(self, recd):
        """Convert int32 prefixed strings into calls to stringReceived.

        Drops the connection if a prefix announces more than MAX_LENGTH bytes.
        """
        self.recvd = self.recvd + recd
        while len(self.recvd) >= 4:
            # Decode the prefix as unsigned: with a signed "!i" a peer could
            # send a negative length that passed the MAX_LENGTH check and
            # then mis-framed the stream (or, for -4, looped forever).
            (length,) = struct.unpack("!I", self.recvd[:4])
            if length > self.MAX_LENGTH:
                self.transport.loseConnection()
                return
            if len(self.recvd) < length + 4:
                break
            packet = self.recvd[4:length + 4]
            self.recvd = self.recvd[length + 4:]
            self.stringReceived(packet)

    def sendString(self, data):
        """Send an int32-prefixed string to the other end of the connection.
        """
        self.transport.write(struct.pack("!I", len(data)) + data)
class Int16StringReceiver(protocol.Protocol):
    """A receiver for int16-prefixed strings.

    An int16 string is a string prefixed by 2 bytes, the 16-bit length of
    the string encoded in network byte order.

    This class publishes the same interface as NetstringReceiver.
    """
    # Bytes received but not yet delivered as complete strings.
    recvd = ""

    def stringReceived(self, msg):
        """Override this.
        """
        raise NotImplementedError

    def dataReceived(self, recd):
        """Convert int16 prefixed strings into calls to stringReceived.
        """
        self.recvd = self.recvd + recd
        while len(self.recvd) >= 2:
            # Unsigned big-endian 16-bit length (0..65535).
            (length,) = struct.unpack("!H", self.recvd[:2])
            if len(self.recvd) < length + 2:
                break
            packet = self.recvd[2:length + 2]
            self.recvd = self.recvd[length + 2:]
            self.stringReceived(packet)

    def sendString(self, data):
        """Send an int16-prefixed string to the other end of the connection.
        """
        assert len(data) < 65536, "message too long"
        # Unsigned "!H" so lengths 32768..65535 (allowed above, and decoded
        # as unsigned by dataReceived) can be sent; signed "!h" rejected them.
        self.transport.write(struct.pack("!H", len(data)) + data)
class StatefulStringProtocol:
    """A stateful string protocol.

    This is a mixin for string protocols (Int32StringReceiver,
    NetstringReceiver) which translates stringReceived into a callback
    (prefixed with 'proto_') depending on state.
    """
    state = 'init'

    def stringReceived(self, string):
        """Route C{string} to the handler for the current state.

        The handler is the method named C{'proto_' + self.state}, starting
        with C{proto_init}.  The handler's return value becomes the new
        state; for example, if C{proto_init} returns C{'foo'}, the next
        string goes to C{proto_foo}.  When the new state is C{'done'}, the
        transport's connection is closed.  If no handler exists for the
        current state, a message is logged and the string is ignored.
        """
        try:
            handler = getattr(self, 'proto_' + self.state)
        except AttributeError:
            log.msg('callback', self.state, 'not found')
            return
        self.state = handler(string)
        if self.state == 'done':
            self.transport.loseConnection()
|