# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
"""Resource limiting policies."""
# system imports
import sys, operator
# twisted imports
from twisted.internet.protocol import ServerFactory,Protocol,ClientFactory
from twisted.internet.interfaces import ITransport
from twisted.internet import reactor
from twisted.python import log
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    The wrapped protocol is handed this object as its transport: transport
    calls made on it are forwarded to the real transport underneath, and
    connection events arriving here are forwarded to the wrapped protocol.
    """

    __implements__ = (ITransport,)

    # Set to 1 once loseConnection() has been requested.
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    # -- events coming up from the real transport -------------------------

    def connectionMade(self):
        """Register with the factory, then attach the wrapped protocol to us."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Pass incoming bytes straight to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)

    # -- transport calls coming down from the wrapped protocol -----------

    def write(self, data):
        """Forward data to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Forward a sequence of data chunks to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Mark ourselves as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return whatever the real transport reports as the peer."""
        return self.transport.getPeer()

    def getHost(self):
        """Return whatever the real transport reports as the host."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register producer with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Forward stopConsuming to the real transport."""
        self.transport.stopConsuming()
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Every protocol built by the wrapped factory is wrapped in an instance of
    the ``protocol`` class attribute. Wrappers that have called
    registerProtocol (and not yet unregisterProtocol) are the keys of
    ``self.protocols``. Client connection callbacks are forwarded unchanged
    to the wrapped factory.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Used as a set: keys are registered wrappers, values are always 1.
        self.protocols = {}

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol for addr and wrap it."""
        wrapperClass = self.protocol
        inner = self.wrappedFactory.buildProtocol(addr)
        return wrapperClass(self, inner)

    # -- bookkeeping ------------------------------------------------------

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]

    # -- client-side callbacks, forwarded verbatim ------------------------

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Reports the size of everything read or written to the factory
    (registerRead / registerWritten). Also exposes the throttle/unthrottle
    hooks that the factory calls on all of its protocols.
    """

    # The producer passed to registerProducer, or None when there is none.
    # Class-level default so the throttle hooks and unregisterProducer
    # work even if no producer was ever registered.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # Sum with an explicit zero start. The previous
        # reduce(operator.add, ...) had no initial value and raised
        # TypeError for an empty sequence.
        total = 0
        for chunk in seq:
            total += len(chunk)
        self.factory.registerWritten(total)
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        # Remember the producer so writes can be throttled by pausing it.
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        # Reset rather than `del`, which raised AttributeError when no
        # producer had been registered.
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading from the underlying transport."""
        self.transport.stopReading()

    def unthrottleReads(self):
        """Resume reading from the underlying transport."""
        self.transport.startReading()

    def throttleWrites(self):
        """Pause the registered producer, if any.

        Data passed directly to write()/writeSequence() is not held back.
        """
        if self.producer is not None:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Resume the registered producer, if any."""
        if self.producer is not None:
            self.producer.resumeProducing()
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    While at least one connection is open, traffic is sampled once per
    second. When a limit is exceeded, every protocol is throttled for long
    enough that the average rate over the sample plus the pause equals the
    limit.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit    # max bytes we should read per second
        self.writeLimit = writeLimit  # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Scheduled reactor calls. Each ID is reset to None whenever its
        # call fires or is cancelled, so a non-None ID is always pending
        # and safe to cancel().
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to tell us more bytes were written."""
        self.writtenThisSecond += length

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """Checks if we've passed the read limit; reschedules itself every second."""
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Reading N bytes in one second at limit L needs a pause of
            # N/L - 1 seconds to bring the average back down to L.
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            if self.unthrottleReadsID is not None:
                # Replace an still-pending unthrottle so only one is
                # outstanding and the stored ID stays accurate.
                self.unthrottleReadsID.cancel()
            self.unthrottleReadsID = reactor.callLater(throttleTime,
                                                       self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """Checks if we've passed the write limit; reschedules itself every second."""
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            if self.unthrottleWritesID is not None:
                # Direct write() calls are not paused, so the limit can be
                # exceeded again while an unthrottle is pending; replace it.
                self.unthrottleWritesID.cancel()
            self.unthrottleWritesID = reactor.callLater(throttleTime,
                                                        self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def throttleWrites(self):
        """Throttle writes on all protocols."""
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.throttleWrites()

    def unthrottleWrites(self):
        """Stop throttling writes on all protocols."""
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleWrites()

    def buildProtocol(self, addr):
        """Build a throttled protocol, or return None when at capacity.

        The first accepted connection starts the per-second bandwidth
        checks. unregisterProtocol stops them again when the last
        connection goes away.
        """
        if self.connectionCount >= self.maxConnectionCount:
            # Refuse before starting any timers. Otherwise (e.g. with
            # maxConnectionCount == 0) they would run forever, because no
            # connection exists whose loss could stop them.
            log.msg("Max connection count reached!")
            return None
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()
        self.connectionCount += 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """Forget p; when no connections remain, stop all scheduled calls."""
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            # Reset each ID after cancelling. A stale, already-cancelled ID
            # would otherwise be cancel()ed again the next time the count
            # reaches zero, which raises.
            if self.unthrottleReadsID is not None:
                self.unthrottleReadsID.cancel()
                self.unthrottleReadsID = None
            if self.checkReadBandwidthID is not None:
                self.checkReadBandwidthID.cancel()
                self.checkReadBandwidthID = None
            if self.unthrottleWritesID is not None:
                self.unthrottleWritesID.cancel()
                self.unthrottleWritesID = None
            if self.checkWriteBandwidthID is not None:
                self.checkWriteBandwidthID.cancel()
                self.checkWriteBandwidthID = None
            # Discard the partial second, so the check that runs when the
            # next connection arrives does not act on stale counts.
            self.readThisSecond = 0
            self.writtenThisSecond = 0
class SpewingProtocol(ProtocolWrapper):
    """ProtocolWrapper that logs the data passing through it via log.msg."""

    def dataReceived(self, data):
        # 1-tuple so the format never tries to unpack data itself.
        log.msg("Received: %r" % (data,))
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        log.msg("Sending: %r" % (data,))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # Without this override, data sent via writeSequence reached the
        # transport unlogged. ProtocolWrapper.writeSequence goes straight to
        # the transport (not through self.write), so chunks are not logged
        # twice. Assumes seq can be iterated twice (e.g. a list) — the
        # original ThrottlingProtocol makes the same assumption.
        for chunk in seq:
            log.msg("Sending: %r" % (chunk,))
        ProtocolWrapper.writeSequence(self, seq)
class SpewingFactory(WrappingFactory):
    """WrappingFactory that wraps each connection in a SpewingProtocol."""

    protocol = SpewingProtocol
|