transport.py :  » Network » Twisted » Twisted-1.0.3 » Twisted-1.0.3 » twisted » conch » ssh » Python Open Source

Home
Python Open Source
1.3.1.2 Python
2.Ajax
3.Aspect Oriented
4.Blog
5.Build
6.Business Application
7.Chart Report
8.Content Management Systems
9.Cryptographic
10.Database
11.Development
12.Editor
13.Email
14.ERP
15.Game 2D 3D
16.GIS
17.GUI
18.IDE
19.Installer
20.IRC
21.Issue Tracker
22.Language Interface
23.Log
24.Math
25.Media Sound Audio
26.Mobile
27.Network
28.Parser
29.PDF
30.Project Management
31.RSS
32.Search
33.Security
34.Template Engines
35.Test
36.UML
37.USB Serial
38.Web Frameworks
39.Web Server
40.Web Services
41.Web Unit
42.Wiki
43.Windows
44.XML
Python Open Source » Network » Twisted 
Twisted » Twisted 1.0.3 » Twisted 1.0.3 » twisted » conch » ssh » transport.py
# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2002 Matthew W. Lefkowitz
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
# 

"""The lowest level SSH protocol.  This handles the key negotiation, the encryption and the compression.

This module is unstable.

Maintainer: U{Paul Swartz<mailto:z3p@twistedmatrix.com>}
"""

# base library imports
import struct
import md5
import sha
import zlib

# external library imports
from Crypto import Util
from Crypto.Cipher import XOR
from Crypto.PublicKey import RSA
from Crypto.Util import randpool

# twisted imports
from twisted.conch import error
from twisted.internet import protocol,defer
from twisted.python import log

# sibling imports
from common import NS,getNS,MP,getMP,_MPpow,ffs,entropy # ease of use
import keys


class SSHTransportBase(protocol.Protocol):
    """Transport logic shared by the SSH client and server.

    Performs the version exchange, frames outgoing packets (padding,
    optional zlib compression, encryption and MAC), parses and verifies
    incoming packets, and handles the transport-level messages DISCONNECT,
    IGNORE, UNIMPLEMENTED and DEBUG.  Incoming message numbers below 50 are
    dispatched to C{ssh_<NAME>} methods on this object; larger numbers are
    handed to the current service's C{packetReceived}.
    """
    protocolVersion = '2.0'
    version = 'Twisted'
    comment = ''
    ourVersionString = ('SSH-'+protocolVersion+'-'+version+' '+comment).strip()

    # Algorithm lists advertised in our KEXINIT, most preferred first.
    # Blowfish is advertised under its registered SSH name 'blowfish-cbc',
    # which is also the key SSHCiphers.cipherMap uses; a bare 'blowfish'
    # could be negotiated but then not looked up.
    supportedCiphers = ['aes256-cbc', 'aes192-cbc', 'aes128-cbc', 'cast128-cbc',
                        'blowfish-cbc', 'idea-cbc', '3des-cbc']
    supportedMACs = ['hmac-sha1', 'hmac-md5']
    supportedKeyExchanges = ['diffie-hellman-group-exchange-sha1',
                             'diffie-hellman-group1-sha1']
    supportedPublicKeys = ['ssh-rsa', 'ssh-dss']
    supportedCompressions = ['none', 'zlib']
    supportedLanguages = ()

    # Per-connection state; these class attributes are the initial values.
    gotVersion = 0              # set once the peer's 'SSH-' line was accepted
    ignoreNextPacket = 0        # set when the peer's guessed kex packet is wrong
    buf = ''                    # received bytes not yet consumed
    outgoingPacketSequence = 0
    incomingPacketSequence = 0
    currentEncryptions = None   # active SSHCiphers; None before the first NEWKEYS
    outgoingCompression = None
    incomingCompression = None
    sessionID = None
    isAuthorized = 0
    service = None

    def connectionLost(self, reason):
        """Stop the running service, if any, and log the loss."""
        if self.service:
            self.service.serviceStopped()
        log.msg('connection lost')

    def connectionMade(self):
        """Send our identification line, then our KEXINIT."""
        self.transport.write('%s\r\n'%(self.ourVersionString))
        self.sendKexInit()

    def sendKexInit(self):
        """Build and send our KEXINIT message.

        The complete payload (including the message-number byte) is stored
        in C{self.ourKexInitPayload} because it is part of the exchange hash.
        The same algorithm lists are offered for both directions.  The final
        five bytes are first_kex_packet_follows (false) and the reserved
        uint32.
        """
        self.ourKexInitPayload = chr(MSG_KEXINIT)+entropy.get_bytes(16)+ \
                       NS(','.join(self.supportedKeyExchanges))+ \
                       NS(','.join(self.supportedPublicKeys))+ \
                       NS(','.join(self.supportedCiphers))+ \
                       NS(','.join(self.supportedCiphers))+ \
                       NS(','.join(self.supportedMACs))+ \
                       NS(','.join(self.supportedMACs))+ \
                       NS(','.join(self.supportedCompressions))+ \
                       NS(','.join(self.supportedCompressions))+ \
                       NS(','.join(self.supportedLanguages))+ \
                       NS(','.join(self.supportedLanguages))+ \
                       '\000'+'\000\000\000\000'
        self.sendPacket(MSG_KEXINIT, self.ourKexInitPayload[1:])

    def sendPacket(self, messageType, payload):
        """Compress, pad, encrypt and MAC one message, then write it.

        @param messageType: the SSH message number.
        @param payload: the message body, without the message-number byte.
        """
        payload = chr(messageType)+payload
        if self.outgoingCompression:
            payload = self.outgoingCompression.compress(payload)
            # 2 == zlib.Z_SYNC_FLUSH: emit everything so the peer can
            # decompress this packet on its own
            payload = payload+self.outgoingCompression.flush(2)
        if self.currentEncryptions:
            bs = self.currentEncryptions.enc_block_size
        else:
            bs = 8
        # length(4) + padding length(1) + payload + padding must be a multiple
        # of the block size, with at least 4 bytes of padding
        totalSize = 5+len(payload)
        lenPad = bs-(totalSize%bs)
        if lenPad < 4:
            lenPad = lenPad+bs
        randomPad = entropy.get_bytes(lenPad)
        packet = struct.pack('!LB', 1+len(payload)+lenPad, lenPad)+ \
                payload+randomPad
        assert len(packet)%bs == 0, '%s extra bytes in packet'%(len(packet)%bs)
        if self.currentEncryptions:
            encPacket = self.currentEncryptions.encrypt(packet)
            assert len(encPacket) == len(packet), '%s %s'%(len(encPacket), len(packet))
            # the MAC is computed over the sequence number and the plaintext
            encPacket = encPacket+self.currentEncryptions.makeMAC(self.outgoingPacketSequence, packet)
        else:
            encPacket = packet
        self.transport.write(encPacket)
        self.outgoingPacketSequence+=1

    def getPacket(self):
        """Extract one complete packet from C{self.buf}, if available.

        @return: the (decompressed) payload, starting with the message-number
        byte; or None when more data is needed or the packet was rejected.
        On rejection a DISCONNECT has already been sent.
        """
        bs = self.currentEncryptions and self.currentEncryptions.dec_block_size or 8
        ms = self.currentEncryptions and self.currentEncryptions.verify_digest_size or 0
        if len(self.buf) < bs: return # not enough data
        # The first block must be decrypted to read the length header.  The
        # cipher state has then advanced, so the plaintext block is kept in
        # self.first until the rest of the packet has arrived.
        if not hasattr(self, 'first'):
            if self.currentEncryptions:
                first = self.currentEncryptions.decrypt(self.buf[: bs])
            else:
                first = self.buf[: bs]
        else:
            first = self.first
            del self.first
        packetLen, randomLen = struct.unpack('!LB', first[: 5])
        if packetLen > 1048576: # 1024 ** 2
            self.sendDisconnect(DISCONNECT_PROTOCOL_ERROR, 'bad packet length %s'%packetLen)
            return
        if len(self.buf) < packetLen+4+ms:
            self.first = first
            return # not enough packet
        if (packetLen+4)%bs != 0:
            self.sendDisconnect(DISCONNECT_PROTOCOL_ERROR, 'bad packet mod (%s%%%s == %s)'%(packetLen+4, bs, (packetLen+4)%bs))
            return
        encData, self.buf = self.buf[: 4+packetLen], self.buf[4+packetLen:]
        if self.currentEncryptions:
            packet = first+self.currentEncryptions.decrypt(encData[bs:])
        else:
            packet = encData
        if len(packet) != 4+packetLen:
            self.sendDisconnect(DISCONNECT_PROTOCOL_ERROR, 'bad packet length')
            return
        if ms:
            macData, self.buf = self.buf[:ms], self.buf[ms:]
            if not self.currentEncryptions.verify(self.incomingPacketSequence, packet, macData):
                self.sendDisconnect(DISCONNECT_MAC_ERROR, 'bad MAC')
                return
        # drop the 5-byte header and the trailing random padding
        payload = packet[5: 4+packetLen-randomLen]
        if self.incomingCompression:
            try:
                payload = self.incomingCompression.decompress(payload)
            except zlib.error:
                self.sendDisconnect(DISCONNECT_COMPRESSION_ERROR, 'compression error')
                return
        self.incomingPacketSequence+=1
        return payload

    def dataReceived(self, data):
        """Buffer incoming bytes, finish the version exchange if needed, then
        dispatch every complete packet now available."""
        self.buf = self.buf+data
        if not self.gotVersion:
            if not self._parseVersion():
                return
        packet = self.getPacket()
        while packet:
            self._dispatchPacket(packet)
            packet = self.getPacket()

    def _parseVersion(self):
        """Look for the peer's 'SSH-' identification line in C{self.buf}.

        Only newline-terminated lines are examined, so a version line split
        across reads is never accepted truncated.  Other complete lines before
        it are discarded.

        @return: true when the version was accepted (C{self.buf} then holds
        what followed it); false when more data is needed or the connection
        is being dropped for an unsupported version.
        """
        lines = self.buf.split('\n')
        # lines[-1] is an unterminated (possibly empty) partial line
        for i in range(len(lines)-1):
            p = lines[i]
            if p[: 4] == 'SSH-':
                self.gotVersion = 1
                self.otherVersionString = p.strip()
                if p.split('-')[1] not in ('1.99', '2.0'): # bad version
                    self.buf = ''
                    self.sendDisconnect(DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED, 'bad version %s'%p.split('-')[1])
                    return 0
                self.buf = '\n'.join(lines[i+1:])
                return 1
        self.buf = lines[-1]
        return 0

    def _dispatchPacket(self, packet):
        """Route one payload to its ssh_* handler or to the service."""
        messageNum = ord(packet[0])
        if messageNum < 50:
            messageName = messages.get(messageNum)
            f = None
            if messageName:
                messageType = messageName[4:]
                f = getattr(self, 'ssh_%s'%messageType, None)
            else:
                # a transport-range number we have no constant for
                messageType = messageNum
            if f:
                f(packet[1:])
            else:
                log.msg("couldn't handle %s"%messageType)
                log.msg(repr(packet[1:]))
                self.sendUnimplemented()
        elif self.service:
            self.service.packetReceived(messageNum, packet[1:])
        else:
            log.msg("couldn't handle %s"%messageNum)
            log.msg(repr(packet[1:]))
            self.sendUnimplemented()

    def ssh_DISCONNECT(self, packet):
        """Handle DISCONNECT: report the reason and close the connection."""
        reasonCode = struct.unpack('>L', packet[: 4])[0]
        description, foo = getNS(packet[4:])
        self.receiveError(reasonCode, description)
        self.transport.loseConnection()

    def ssh_IGNORE(self, packet): pass

    def ssh_UNIMPLEMENTED(self, packet):
        """Handle UNIMPLEMENTED: the body is the peer's rejected sequence number."""
        seqnum = struct.unpack('>L', packet[: 4])[0]
        self.receiveUnimplemented(seqnum)

    def ssh_DEBUG(self, packet):
        """Handle DEBUG: boolean always_display, then message and language strings."""
        alwaysDisplay = ord(packet[0])
        # the two strings start after the one-byte boolean
        message, lang, foo = getNS(packet[1:], 2)
        self.receiveDebug(alwaysDisplay, message, lang)

    def setService(self, service):
        """Stop any running service and start C{service} on this transport."""
        log.msg('starting service %s'%service.name)
        if self.service:
            self.service.serviceStopped()
        self.service = service
        service.transport = self
        self.service.serviceStarted()

    def sendDebug(self, message, alwaysDisplay = 0, language = ''):
        self.sendPacket(MSG_DEBUG, chr(alwaysDisplay)+NS(message)+NS(language))

    def sendIgnore(self, message):
        self.sendPacket(MSG_IGNORE, NS(message))

    def sendUnimplemented(self):
        """Tell the peer the most recently received packet was not understood.

        getPacket() has already advanced the incoming counter past that
        packet, so its sequence number is one less than the counter.
        """
        seqnum = self.incomingPacketSequence - 1
        self.sendPacket(MSG_UNIMPLEMENTED, struct.pack('!L', seqnum))

    def sendDisconnect(self, reason, desc):
        """Send DISCONNECT with C{reason} and C{desc}, then drop the connection."""
        self.sendPacket(MSG_DISCONNECT, struct.pack('>L', reason)+NS(desc)+NS(''))
        log.msg('Disconnecting with error, code %s\nreason: %s'%(reason, desc))
        self.transport.loseConnection()

    # client methods
    def receiveError(self, reasonCode, description):
        log.msg('Got remote error, code %s\nreason: %s'%(reasonCode, description))

    def receiveUnimplemented(self, seqnum):
        log.msg('other side unimplemented packet #%s'%seqnum)

    def receiveDebug(self, alwaysDisplay, message, lang):
        if alwaysDisplay:
            log.msg('Remote Debug Message:', message)

    def isEncrypted(self, direction = "out"):
        """Return true if traffic in C{direction} is encrypted.

        @param direction: "out", "in" or "both".
        @raise TypeError: for any other direction.
        """
        if self.currentEncryptions is None:
            return 0
        elif direction == "out":
            return self.currentEncryptions.outCip is not None
        elif direction == "in":
            return self.currentEncryptions.inCip is not None
        elif direction == "both":
            return self.isEncrypted("in") and self.isEncrypted("out")
        else:
            raise TypeError('direction must be "out", "in", or "both"')

    def isVerified(self, direction = "out"):
        """Return true if traffic in C{direction} is MAC-protected.

        @param direction: "out", "in" or "both".
        @raise TypeError: for any other direction.
        """
        if self.currentEncryptions is None:
            return 0
        elif direction == "out":
            return self.currentEncryptions.outMAC is not None
        elif direction == "in":
            return self.currentEncryptions.inMAC is not None
        elif direction == "both":
            return self.isVerified("in") and self.isVerified("out")
        else:
            raise TypeError('direction must be "out", "in", or "both"')

class SSHServerTransport(SSHTransportBase):
    """Server side of the SSH transport.

    Answers the client's key exchange and service requests.  Uses
    C{self.factory.publicKeys} / C{privateKeys} (indexed by host key
    algorithm name), C{self.factory.getDHPrime(ideal)} returning
    C{(g, p)}, and C{self.factory.getService(transport, name)} returning a
    service class, or a false value if the service is unavailable.
    """
    isClient = 0

    def ssh_KEXINIT(self, packet):
        """Record the client's KEXINIT and choose the algorithms to use.

        Each choice is made with ffs(clientList, ourList), which presumably
        returns the first client-listed name we also support.  If the client
        says a guessed kex packet follows and its first choices differ from
        ours, that packet is marked to be ignored.
        """
        self.clientKexInitPayload = chr(MSG_KEXINIT)+packet
        # skip the 16-byte random cookie; ten name-lists follow it
        k = getNS(packet[16:], 10)
        strings, rest = k[:-1], k[-1]
        kexAlgs, keyAlgs, encCS, encSC, macCS, macSC, compCS, compSC, langCS, langSC =  \
           [s.split(',') for s in strings]
        if ord(rest[0]): # first_kex_packet_follows
            if kexAlgs[0] != self.supportedKeyExchanges[0] or \
               keyAlgs[0] != self.supportedPublicKeys[0] or \
               not ffs(encSC, self.supportedCiphers) or \
               not ffs(encCS, self.supportedCiphers) or \
               not ffs(macSC, self.supportedMACs) or \
               not ffs(macCS, self.supportedMACs) or \
               not ffs(compCS, self.supportedCompressions) or \
               not ffs(compSC, self.supportedCompressions):
                self.ignoreNextPacket = 1 # guess was wrong
        self.kexAlg = ffs(kexAlgs, self.supportedKeyExchanges)
        self.keyAlg = ffs(keyAlgs, self.supportedPublicKeys)
        # for the server, "out" is server->client and "in" is client->server
        self.nextEncryptions = SSHCiphers(
            ffs(encSC, self.supportedCiphers),
            ffs(encCS, self.supportedCiphers),
            ffs(macSC, self.supportedMACs),
            ffs(macCS, self.supportedMACs),
         )
        self.outgoingCompressionType = ffs(compSC, self.supportedCompressions)
        self.incomingCompressionType = ffs(compCS, self.supportedCompressions)
        if None in (self.kexAlg, self.keyAlg, self.outgoingCompressionType, self.incomingCompressionType):
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, "couldn't match all kex parts")
            return
        if None in self.nextEncryptions.__dict__.values():
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, "couldn't match all kex parts")
            return
        log.msg('kex alg, key alg: %s %s'%(self.kexAlg, self.keyAlg))
        log.msg('server->client: %s %s %s'%(self.nextEncryptions.outCipType,
                                            self.nextEncryptions.outMacType,
                                            self.outgoingCompressionType))
        log.msg('client->server: %s %s %s'%(self.nextEncryptions.inCipType,
                                            self.nextEncryptions.inMacType,
                                            self.incomingCompressionType))

    def ssh_KEX_DH_GEX_REQUEST_OLD(self, packet):
        """Handle message 30: KEXDH_INIT or KEX_DH_GEX_REQUEST_OLD.

        Both messages share number 30.  Which one this is depends on the
        negotiated key exchange algorithm.
        """
        if self.ignoreNextPacket:
            self.ignoreNextPacket = 0
            return
        if self.kexAlg == 'diffie-hellman-group1-sha1': # this is really KEXDH_INIT
            clientDHPubKey, foo = getMP(packet)
            # RFC 4253 section 8: e outside [1, p-1] must not be accepted
            if not 1 <= clientDHPubKey < DH_PRIME:
                self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, 'bad DH public value')
                return
            # A 16-bit private exponent could be brute-forced from f; use
            # the same 512-bit size the client side uses.
            y = Util.number.getRandomNumber(512, entropy.get_bytes)
            f = pow(DH_GENERATOR, y, DH_PRIME)
            sharedSecret = _MPpow(clientDHPubKey, y, DH_PRIME)
            h = sha.new()
            h.update(NS(self.otherVersionString))
            h.update(NS(self.ourVersionString))
            h.update(NS(self.clientKexInitPayload))
            h.update(NS(self.ourKexInitPayload))
            h.update(NS(self.factory.publicKeys[self.keyAlg]))
            h.update(MP(clientDHPubKey))
            h.update(MP(f))
            h.update(sharedSecret)
            exchangeHash = h.digest()
            self.sendPacket(MSG_KEXDH_REPLY, NS(self.factory.publicKeys[self.keyAlg])+ \
                           MP(f)+NS(keys.signData(self.factory.privateKeys[self.keyAlg], exchangeHash)))
            self._keySetup(sharedSecret, exchangeHash)
        elif self.kexAlg == 'diffie-hellman-group-exchange-sha1':
            # Mark the old-style request (ideal size only) so that
            # ssh_KEX_DH_GEX_INIT hashes the matching fields.
            self.kexAlg = 'diffie-hellman-group-exchange-sha1-old'
            self.ideal = struct.unpack('>L', packet)[0]
            self.g, self.p = self.factory.getDHPrime(self.ideal)
            self.sendPacket(MSG_KEX_DH_GEX_GROUP, MP(self.p)+MP(self.g))
        else:
            raise error.ConchError('bad kexalg: %s'%self.kexAlg)

    def ssh_KEX_DH_GEX_REQUEST(self, packet):
        """Handle a group-exchange request carrying (min, ideal, max) sizes."""
        if self.ignoreNextPacket:
            self.ignoreNextPacket = 0
            return
        self.min, self.ideal, self.max = struct.unpack('>3L', packet)
        self.g, self.p = self.factory.getDHPrime(self.ideal)
        self.sendPacket(MSG_KEX_DH_GEX_GROUP, MP(self.p)+MP(self.g))

    def ssh_KEX_DH_GEX_INIT(self, packet):
        """Finish a group exchange: compute f and the shared secret, sign the
        exchange hash and reply."""
        clientDHPubKey, foo = getMP(packet)
        # RFC 4253 section 8: e outside [1, p-1] must not be accepted
        if not 1 <= clientDHPubKey < self.p:
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, 'bad DH public value')
            return
        y = Util.number.getRandomNumber(512, entropy.get_bytes)
        f = pow(self.g, y, self.p)
        sharedSecret = _MPpow(clientDHPubKey, y, self.p)
        h = sha.new()
        h.update(NS(self.otherVersionString))
        h.update(NS(self.ourVersionString))
        h.update(NS(self.clientKexInitPayload))
        h.update(NS(self.ourKexInitPayload))
        h.update(NS(self.factory.publicKeys[self.keyAlg]))
        if self.kexAlg == 'diffie-hellman-group-exchange-sha1':
            h.update(struct.pack('>3L', self.min, self.ideal, self.max))
        else:
            # old-style request: only the ideal size is hashed
            h.update(struct.pack('>L', self.ideal))
        h.update(MP(self.p))
        h.update(MP(self.g))
        h.update(MP(clientDHPubKey))
        h.update(MP(f))
        h.update(sharedSecret)
        exchangeHash = h.digest()
        self.sendPacket(MSG_KEX_DH_GEX_REPLY, NS(self.factory.publicKeys[self.keyAlg])+ \
                       MP(f)+NS(keys.signData(self.factory.privateKeys[self.keyAlg], exchangeHash)))
        self._keySetup(sharedSecret, exchangeHash)

    def ssh_NEWKEYS(self, packet):
        """Switch to the newly negotiated keys and compression."""
        if packet != '':
            self.sendDisconnect(DISCONNECT_PROTOCOL_ERROR, "NEWKEYS takes no data")
            return
        self.currentEncryptions = self.nextEncryptions
        if self.outgoingCompressionType == 'zlib':
            self.outgoingCompression = zlib.compressobj(6)
        if self.incomingCompressionType == 'zlib':
            self.incomingCompression = zlib.decompressobj()

    def ssh_SERVICE_REQUEST(self, packet):
        """Start the requested service, or disconnect if it is unavailable."""
        service, rest = getNS(packet)
        cls = self.factory.getService(self, service)
        if not cls:
            self.sendDisconnect(DISCONNECT_SERVICE_NOT_AVAILABLE, "don't have service %s"%service)
            return
        else:
            self.sendPacket(MSG_SERVICE_ACCEPT, NS(service))
            self.setService(cls())

    def _keySetup(self, sharedSecret, exchangeHash):
        """Derive IVs and keys (letters A-F), load them into nextEncryptions
        with server->client as the outgoing direction, and send NEWKEYS.
        The first exchange hash becomes the session ID."""
        if not self.sessionID:
            self.sessionID = exchangeHash
        initIVCS = self._getKey('A', sharedSecret, exchangeHash)
        initIVSC = self._getKey('B', sharedSecret, exchangeHash)
        encKeyCS = self._getKey('C', sharedSecret, exchangeHash)
        encKeySC = self._getKey('D', sharedSecret, exchangeHash)
        integKeyCS = self._getKey('E', sharedSecret, exchangeHash)
        integKeySC = self._getKey('F', sharedSecret, exchangeHash)
        self.nextEncryptions.setKeys(initIVSC, encKeySC, initIVCS, encKeyCS, integKeySC, integKeyCS)
        self.sendPacket(MSG_NEWKEYS, '')

    def _getKey(self, c, sharedSecret, exchangeHash):
        """Return SHA1(K||H||c||session_id) followed by SHA1(K||H||first part)."""
        k1 = sha.new(sharedSecret+exchangeHash+c+self.sessionID).digest()
        k2 = sha.new(sharedSecret+exchangeHash+k1).digest()
        return k1+k2

class SSHClientTransport(SSHTransportBase):
    """Client side of the SSH transport.

    Subclasses must implement verifyHostKey() and connectionSecure(), and
    call requestService() to start a service once the connection is secure.
    """
    isClient = 1

    def connectionMade(self):
        SSHTransportBase.connectionMade(self)
        # set when the server's NEWKEYS arrives before our keys are computed
        # (e.g. while verifyHostKey's Deferred is still pending)
        self._gotNewKeys = 0

    def ssh_KEXINIT(self, packet):
        """Pick algorithms from the server's KEXINIT and start key exchange.

        Choices use ffs(ourList, serverList), so our (the client's)
        preference order is consulted first.
        """
        self.serverKexInitPayload = chr(MSG_KEXINIT)+packet
        # skip the 16-byte random cookie; ten name-lists follow it
        k = getNS(packet[16:], 10)
        strings, rest = k[:-1], k[-1]
        kexAlgs, keyAlgs, encCS, encSC, macCS, macSC, compCS, compSC, langCS, langSC =  \
           [s.split(',') for s in strings]
        self.kexAlg = ffs(self.supportedKeyExchanges, kexAlgs)
        self.keyAlg = ffs(self.supportedPublicKeys, keyAlgs)
        # for the client, "out" is client->server and "in" is server->client
        self.nextEncryptions = SSHCiphers(
            ffs(self.supportedCiphers, encCS),
            ffs(self.supportedCiphers, encSC),
            ffs(self.supportedMACs, macCS),
            ffs(self.supportedMACs, macSC),
         )
        self.outgoingCompressionType = ffs(self.supportedCompressions, compCS)
        self.incomingCompressionType = ffs(self.supportedCompressions, compSC)
        if None in (self.kexAlg, self.keyAlg, self.outgoingCompressionType, self.incomingCompressionType):
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, "couldn't match all kex parts")
            return
        if None in self.nextEncryptions.__dict__.values():
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, "couldn't match all kex parts")
            return
        log.msg('kex alg, key alg: %s %s'%(self.kexAlg, self.keyAlg))
        log.msg('client->server: %s %s %s'%(self.nextEncryptions.outCipType,
                                            self.nextEncryptions.outMacType,
                                            self.outgoingCompressionType))
        log.msg('server->client: %s %s %s'%(self.nextEncryptions.inCipType,
                                            self.nextEncryptions.inMacType,
                                            self.incomingCompressionType))

        if self.kexAlg == 'diffie-hellman-group1-sha1':
            self.x = Util.number.getRandomNumber(512, entropy.get_bytes)
            self.DHpubKey = pow(DH_GENERATOR, self.x, DH_PRIME)
            self.sendPacket(MSG_KEXDH_INIT, MP(self.DHpubKey))
        else:
            # old-style group exchange request: ideal group size 2048 bits
            self.sendPacket(MSG_KEX_DH_GEX_REQUEST_OLD, '\x00\x00\x08\x00')

    def ssh_KEX_DH_GEX_GROUP(self, packet):
        """Handle message 31: KEXDH_REPLY (group1) or KEX_DH_GEX_GROUP."""
        if self.kexAlg == 'diffie-hellman-group1-sha1':
            pubKey, packet = getNS(packet)
            f, packet = getMP(packet)
            signature, packet = getNS(packet)
            fingerprint = ':'.join(map(lambda c: '%02x'%ord(c), md5.new(pubKey).digest()))
            d = self.verifyHostKey(pubKey, fingerprint)
            d.addCallback(self._continueGEX_GROUP, pubKey, f, signature)
            d.addErrback(lambda unused,self=self:self.sendDisconnect(DISCONNECT_HOST_KEY_NOT_VERIFIABLE, 'bad host key'))
        else:
            self.p, rest = getMP(packet)
            self.g, rest = getMP(rest)
            # 512-bit private exponent from 64 random bytes
            self.x = getMP('\x00\x00\x00\x40'+entropy.get_bytes(64))[0]
            self.DHpubKey = pow(self.g, self.x, self.p)
            self.sendPacket(MSG_KEX_DH_GEX_INIT, MP(self.DHpubKey))

    def _continueGEX_GROUP(self, ignored, pubKey, f, signature):
        """After host key verification, check the group1 signature and set up keys."""
        # RFC 4253 section 8: f outside [1, p-1] must not be accepted
        if not 1 <= f < DH_PRIME:
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, 'bad DH public value')
            return
        serverKey = keys.getPublicKeyObject(pubKey)
        sharedSecret = _MPpow(f, self.x, DH_PRIME)
        h = sha.new()
        h.update(NS(self.ourVersionString))
        h.update(NS(self.otherVersionString))
        h.update(NS(self.ourKexInitPayload))
        h.update(NS(self.serverKexInitPayload))
        h.update(NS(pubKey))
        h.update(MP(self.DHpubKey))
        h.update(MP(f))
        h.update(sharedSecret)
        exchangeHash = h.digest()
        if not keys.verifySignature(serverKey, signature, exchangeHash):
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, 'bad signature')
            return
        self._keySetup(sharedSecret, exchangeHash)

    def ssh_KEX_DH_GEX_REPLY(self, packet):
        """Handle the group-exchange reply: verify the host key, then continue."""
        pubKey, packet = getNS(packet)
        f, packet = getMP(packet)
        signature, packet = getNS(packet)
        fingerprint = ':'.join(map(lambda c: '%02x'%ord(c), md5.new(pubKey).digest()))
        d = self.verifyHostKey(pubKey, fingerprint)
        d.addCallback(self._continueGEX_REPLY, pubKey, f, signature)
        d.addErrback(lambda unused, self=self: self.sendDisconnect(DISCONNECT_HOST_KEY_NOT_VERIFIABLE, 'bad host key'))

    def _continueGEX_REPLY(self, ignored, pubKey, f, signature):
        """After host key verification, check the GEX signature and set up keys."""
        # RFC 4253 section 8: f outside [1, p-1] must not be accepted
        if not 1 <= f < self.p:
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, 'bad DH public value')
            return
        serverKey = keys.getPublicKeyObject(pubKey)
        sharedSecret = _MPpow(f, self.x, self.p)
        h = sha.new()
        h.update(NS(self.ourVersionString))
        h.update(NS(self.otherVersionString))
        h.update(NS(self.ourKexInitPayload))
        h.update(NS(self.serverKexInitPayload))
        h.update(NS(pubKey))
        # the ideal size we sent in KEX_DH_GEX_REQUEST_OLD
        h.update('\x00\x00\x08\x00')
        h.update(MP(self.p))
        h.update(MP(self.g))
        h.update(MP(self.DHpubKey))
        h.update(MP(f))
        h.update(sharedSecret)
        exchangeHash = h.digest()
        if not keys.verifySignature(serverKey, signature, exchangeHash):
            self.sendDisconnect(DISCONNECT_KEY_EXCHANGE_FAILED, 'bad signature')
            return
        self._keySetup(sharedSecret, exchangeHash)

    def _keySetup(self, sharedSecret, exchangeHash):
        """Derive IVs and keys (letters A-F), load them into nextEncryptions
        with client->server as the outgoing direction, and send NEWKEYS.
        If the server's NEWKEYS already arrived, switch keys now."""
        if not self.sessionID:
            self.sessionID = exchangeHash
        initIVCS = self._getKey('A', sharedSecret, exchangeHash)
        initIVSC = self._getKey('B', sharedSecret, exchangeHash)
        encKeyCS = self._getKey('C', sharedSecret, exchangeHash)
        encKeySC = self._getKey('D', sharedSecret, exchangeHash)
        integKeyCS = self._getKey('E', sharedSecret, exchangeHash)
        integKeySC = self._getKey('F', sharedSecret, exchangeHash)
        self.nextEncryptions.setKeys(initIVCS, encKeyCS, initIVSC, encKeySC, integKeyCS, integKeySC)
        self.sendPacket(MSG_NEWKEYS, '')
        if self._gotNewKeys:
            self.ssh_NEWKEYS('')

    def _getKey(self, c, sharedSecret, exchangeHash):
        """Return SHA1(K||H||c||session_id) followed by SHA1(K||H||first part)."""
        k1 = sha.new(sharedSecret+exchangeHash+c+self.sessionID).digest()
        k2 = sha.new(sharedSecret+exchangeHash+k1).digest()
        return k1+k2

    def ssh_NEWKEYS(self, packet):
        """Switch to the new keys, or defer until _keySetup has produced them."""
        if packet != '':
            self.sendDisconnect(DISCONNECT_PROTOCOL_ERROR, "NEWKEYS takes no data")
            return
        if not hasattr(self.nextEncryptions, 'outCip'):
            # setKeys() has not run yet; _keySetup will call us again
            self._gotNewKeys = 1
            return
        self.currentEncryptions = self.nextEncryptions
        if self.outgoingCompressionType == 'zlib':
            self.outgoingCompression = zlib.compressobj(6)
        if self.incomingCompressionType == 'zlib':
            self.incomingCompression = zlib.decompressobj()
        self.connectionSecure()

    def ssh_SERVICE_ACCEPT(self, packet):
        """Start the service we asked for; reject an accept we did not request."""
        name = getNS(packet)[0]
        instance = getattr(self, 'instance', None)
        if instance is None or name != instance.name:
            self.sendDisconnect(DISCONNECT_PROTOCOL_ERROR, "received accept for service we did not request")
            return
        self.setService(instance)

    def requestService(self, instance):
        """
        Request that a service be run over this transport.

        @type instance: subclass of C{twisted.conch.ssh.service.SSHService}
        """
        self.sendPacket(MSG_SERVICE_REQUEST, NS(instance.name))
        self.instance = instance

    # client methods
    def verifyHostKey(self, hostKey, fingerprint):
        """Returns a Deferred that gets a callback if it is a valid key, or
        an errback if not.  The default implementation always fails.

        @type hostKey:      C{str}
        @type fingerprint:  C{str}
        @rtype:             C{Deferred}
        """
        return defer.fail(NotImplementedError())

    def connectionSecure(self):
        """
        Called when the encryption has been set up.  Generally,
        requestService() is called to run another service over the transport.
        """
        raise NotImplementedError

class SSHCiphers:
    """Cipher and HMAC state for one negotiated key set, both directions.

    Constructed with the negotiated algorithm names; setKeys() then creates
    the cipher objects and MAC keys.  "out" is the direction this side
    sends in and "in" the direction it receives from.
    """
    # SSH cipher name -> (Crypto.Cipher module name, key length in bytes)
    cipherMap = {
        '3des-cbc':('DES3', 24),
        'blowfish-cbc':('Blowfish', 16),
        'aes256-cbc':('AES', 32),
        'aes192-cbc':('AES', 24),
        'aes128-cbc':('AES', 16),
        'arcfour':('ARC4', 16),
        'idea-cbc':('IDEA', 16),
        'cast128-cbc':('CAST', 16),
        'none':(None, None),
     }
    # SSH MAC name -> name of the hash module used for HMAC
    macMap = {
        'hmac-sha1': 'sha',
        'hmac-md5': 'md5',
        'none': None,
     }

    def __init__(self, outCip, inCip, outMac, inMac):
        # Only these four attributes may be set here: the transports test
        # `None in nextEncryptions.__dict__.values()` to detect a failed
        # negotiation, and the client uses hasattr(..., 'outCip') to tell
        # whether setKeys() has run.
        self.outCipType = outCip
        self.inCipType = inCip
        self.outMacType = outMac
        self.inMacType = inMac

    def setKeys(self, outIV, outKey, inIV, inKey, outInteg, inInteg):
        """Create the cipher objects and MAC state from derived key material.

        A 'none' cipher leaves outCip/inCip as None with an 8-byte block size;
        a 'none' MAC leaves outMAC/inMAC as None with a 0-byte digest size.
        """
        self.outCip = self._getCipher(self.outCipType, outIV, outKey)
        if self.outCip is None:
            self.enc_block_size = 8
        else:
            self.enc_block_size = self.outCip.block_size
        self.inCip = self._getCipher(self.inCipType, inIV, inKey)
        if self.inCip is None:
            self.dec_block_size = 8
        else:
            self.dec_block_size = self.inCip.block_size
        self.outMAC = self._getMAC(self.outMacType, outInteg)
        self.inMAC = self._getMAC(self.inMacType, inInteg)
        if self.inMAC is None:
            self.verify_digest_size = 0
        else:
            self.verify_digest_size = self.inMAC[2]

    def _getCipher(self, cip, iv, key):
        """Return a CBC-mode cipher object for C{cip}, or None for 'none'."""
        modName, keySize = self.cipherMap[cip]
        if not modName: return # no cipher
        mod = __import__('Crypto.Cipher.%s'%modName, {}, {}, 'x')
        return mod.new(key[: keySize], mod.MODE_CBC, iv[: mod.block_size])

    def _getMAC(self, mac, key):
        """Return (hash module, padded key, digest size), or None for 'none'.

        The key is truncated to the digest size and zero-padded to 64 bytes,
        the input block size of MD5 and SHA-1, for the HMAC construction.
        """
        modName = self.macMap[mac]
        if not modName: return
        mod = __import__(modName, {}, {}, '')
        if not hasattr(mod, 'digest_size'):
            ds = len(mod.new().digest())
        else:
            ds = mod.digest_size
        key = key[: ds]+'\x00'*(64-ds)
        return mod, key, ds

    def encrypt(self, blocks):
        """Encrypt with the outgoing cipher; identity if there is none."""
        if self.outCip is None:
            return blocks
        return self.outCip.encrypt(blocks)

    def decrypt(self, blocks):
        """Decrypt with the incoming cipher; identity if there is none."""
        if self.inCip is None:
            return blocks
        return self.inCip.decrypt(blocks)

    def makeMAC(self, seqid, data):
        """Return HMAC(sequence number || data), or '' when no outgoing MAC."""
        if self.outMAC is None:
            return ''
        data = struct.pack('>L', seqid)+data
        mod, key, ds = self.outMAC
        inner = mod.new(XOR.new('\x36').encrypt(key)+data)
        outer = mod.new(XOR.new('\x5c').encrypt(key)+inner.digest())
        return outer.digest()

    def verify(self, seqid, data, mac):
        """Check C{mac} against HMAC(sequence number || data).

        With no incoming MAC only an empty C{mac} is accepted.  The
        comparison does not stop at the first differing byte, so response
        timing does not reveal how much of a forged MAC was correct.
        """
        if self.inMAC is None:
            return mac == ''
        data = struct.pack('>L', seqid)+data
        mod, key, ds = self.inMAC
        inner = mod.new(XOR.new('\x36').encrypt(key)+data)
        outer = mod.new(XOR.new('\x5c').encrypt(key)+inner.digest())
        return self._constantTimeEqual(mac, outer.digest())

    def _constantTimeEqual(self, a, b):
        """Compare two strings without early exit on the first mismatch."""
        if len(a) != len(b):
            return 0
        result = 0
        for x, y in zip(a, b):
            result = result | (ord(x) ^ ord(y))
        return result == 0

def buffer_dump(b, title = None):
    """Return a hex dump of string C{b}, prefixed by C{title}.

    Each 16-byte row becomes one line of space-separated 4-hex-digit
    groups (two bytes per group).  A trailing odd byte is printed as two
    hex digits with no following space.  The title is not followed by a
    newline.
    """
    if not b:
        return title or ''
    pieces = [title or '']
    for rowStart in range(0, len(b), 16):
        row = b[rowStart:rowStart+16]
        for pairStart in range(0, len(row), 2):
            pair = row[pairStart:pairStart+2]
            if len(pair) == 2:
                pieces.append('%02x%02x ' % (ord(pair[0]), ord(pair[1])))
            else:
                pieces.append('%02x' % ord(pair[0]))
        pieces.append('\n')
    return ''.join(pieces)

# Modulus and generator for the 'diffie-hellman-group1-sha1' key exchange
# (used by both transports for that method).  RFC 4253 section 8.1 assigns
# Oakley Group 2 (RFC 2409) to this method; presumably this is that 1024-bit
# prime -- not re-verified digit by digit here.
DH_PRIME = 179769313486231590770839156793787453197860296048756011706444423684197180216158519368947833795864925541502180565485980503646440548199239100050792877003355816639229553136239076508735759914822574862575007425302077447712589550957937778424442426617334727629299387668709205606050270810842907692932019128194467627007L
DH_GENERATOR = 2L

# SSH message numbers.  30 and 31 are deliberately shared: KEXDH_INIT/REPLY
# (diffie-hellman-group1-sha1) and KEX_DH_GEX_REQUEST_OLD/GROUP (group
# exchange) reuse the same numbers.
MSG_DISCONNECT = 1
MSG_IGNORE = 2
MSG_UNIMPLEMENTED = 3
MSG_DEBUG = 4
MSG_SERVICE_REQUEST = 5
MSG_SERVICE_ACCEPT = 6
MSG_KEXINIT = 20
MSG_NEWKEYS = 21
MSG_KEXDH_INIT = 30
MSG_KEXDH_REPLY = 31
MSG_KEX_DH_GEX_REQUEST_OLD = 30
MSG_KEX_DH_GEX_REQUEST = 34
MSG_KEX_DH_GEX_GROUP = 31
MSG_KEX_DH_GEX_INIT = 32
MSG_KEX_DH_GEX_REPLY = 33

# Reason codes carried in MSG_DISCONNECT.
DISCONNECT_HOST_NOT_ALLOWED_TO_CONNECT = 1
DISCONNECT_PROTOCOL_ERROR = 2
DISCONNECT_KEY_EXCHANGE_FAILED = 3
DISCONNECT_RESERVED = 4
DISCONNECT_MAC_ERROR = 5
DISCONNECT_COMPRESSION_ERROR = 6
DISCONNECT_SERVICE_NOT_AVAILABLE = 7
DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED = 8
DISCONNECT_HOST_KEY_NOT_VERIFIABLE = 9
DISCONNECT_CONNECTION_LOST = 10
DISCONNECT_BY_APPLICATION = 11
DISCONNECT_TOO_MANY_CONNECTIONS = 12
DISCONNECT_AUTH_CANCELLED_BY_USER = 13
DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 14
DISCONNECT_ILLEGAL_USER_NAME = 15

def _buildMessageTable(namespace):
    """Map each MSG_* value in C{namespace} to its constant name.

    Names are visited in sorted order, so when two names share a number the
    alphabetically later one wins: 30 -> 'MSG_KEX_DH_GEX_REQUEST_OLD' and
    31 -> 'MSG_KEX_DH_GEX_GROUP'.  Dispatch relies on this: message 30 goes
    to ssh_KEX_DH_GEX_REQUEST_OLD and 31 to ssh_KEX_DH_GEX_GROUP, which
    handle the KEXDH variants themselves.
    """
    table = {}
    names = list(namespace.keys())
    names.sort()
    for name in names:
        if name[:4] == 'MSG_':
            table[namespace[name]] = name
    return table

# message number -> 'MSG_<NAME>'; used to pick the ssh_<NAME> handler
messages = _buildMessageTable(globals())

www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.