#!/usr/bin/env python

import json
import random
import argparse

import pygraphviz


def parse_args():
    """Parse the command line and return the resulting namespace.

    The positional ``jsonfile`` is opened for reading by argparse itself.
    ``--no-rtt`` and ``--generate-dot`` are on/off switches defaulting to
    False.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'jsonfile',
        type=argparse.FileType('r'),
        help='The JSON file generated by dublin-traceroute',
    )
    # Both switches share the same settings; only the name and help differ.
    switches = (
        ('--no-rtt',
         'Do not add RTT information to the graph (default: False)'),
        ('--generate-dot',
         'Generate the intermediate DOT file'),
    )
    for option, help_text in switches:
        parser.add_argument(option, action='store_true', default=False,
                            help=help_text)
    return parser.parse_args()


def _hop_nodename(hop, index):
    """Return the graph node name for *hop* at 1-based position *index*.

    Unanswered hops (``received`` is None) are named ``NULL<index>``;
    answered hops are named after the source IP of the reply.
    """
    received = hop['received']
    if received is None:
        return 'NULL{idx}'.format(idx=index)
    return received['ip']['src']


def _format_rtt(rtt_usec):
    """Format a round-trip time given in microseconds as milliseconds.

    The sub-millisecond part is zero-padded to three digits, so 1005 usec
    renders as ``1.005 ms`` rather than ``1.5 ms``.
    """
    rtt_usec = int(rtt_usec)
    return '{ms}.{frac:03d} ms'.format(ms=rtt_usec // 1000,
                                      frac=rtt_usec % 1000)


def json_to_graphviz(traceroute, no_rtt=False):
    """Build a directed pygraphviz graph from a dublin-traceroute result.

    Every flow in ``traceroute['flows']`` becomes a chain of edges drawn in
    one random colour. The chain starts at the source host, taken from the
    first hop's sent IP header and drawn as a rectangle, then runs through
    each hop in order.

    :param traceroute: decoded dublin-traceroute JSON. Flow keys are shown
        as the destination port on the first edge. Each hop is read for
        ``sent`` (first hop only), ``received``, ``name``, and optionally
        ``is_last``, ``nat_id`` and ``rtt_usec``.
    :param no_rtt: when True, RTT values are not added to edge labels.
    :returns: the populated ``pygraphviz.AGraph``.
    """
    graph = pygraphviz.AGraph(strict=False, directed=True)
    graph.node_attr['shape'] = 'ellipse'
    graph.graph_attr['rankdir'] = 'BT'

    for flow, hops in traceroute['flows'].items():
        if not hops:
            # No hops for this flow: there is no source IP to anchor it.
            continue

        # Graphviz RGB colours must be '#rrggbb', so pad to six hex digits.
        color = '#{c:06x}'.format(c=random.randrange(0, 0xffffff))

        # The source host starts every chain and is drawn as a rectangle.
        previous_nodename = hops[0]['sent']['ip']['src']
        graph.add_node(previous_nodename, shape='rectangle')
        previous_nat_id = 0

        # Hops are numbered from 1; position 0 is the source host above.
        for index, hop in enumerate(hops, start=1):
            nodename = _hop_nodename(hop, index)

            received = hop['received']
            if received is None:
                nodeattrs = {'label': '*'}
            else:
                # Only show the resolved name when it differs from the IP.
                hostname = hop['name']
                extra = '' if hostname == nodename else '\n' + hostname
                nodeattrs = {'label': '{ip}{name}\n{icmp}'.format(
                    ip=nodename,
                    name=extra,
                    icmp=received['icmp']['description'],
                )}
            if hop.get('is_last'):
                nodeattrs['shape'] = 'rectangle'
            graph.add_node(nodename, **nodeattrs)

            # Edge from the previous hop (or the source host) to this hop.
            label = 'dport\n{dp}'.format(dp=flow) if index == 1 else ''
            if 'nat_id' in hop:
                # Flag the edge when the NAT id differs from the previous
                # hop's NAT id (the source host counts as NAT id 0).
                if hop['nat_id'] != previous_nat_id:
                    label += '\nNAT detected'
                previous_nat_id = hop['nat_id']
            rtt = hop.get('rtt_usec')
            if not no_rtt and rtt is not None:
                label += '\n' + _format_rtt(rtt)
            # Pass the attributes directly. In a non-strict graph each
            # add_edge call can create a parallel edge, and get_edge(u, v)
            # would only return (and restyle) the first one.
            graph.add_edge(previous_nodename, nodename,
                           color=color, label=label)

            previous_nodename = nodename

    return graph


def main():
    """Render the traceroute JSON named on the command line to a PNG.

    The PNG is written to ``<jsonfile name>.png``. With ``--generate-dot``
    the laid-out DOT source is also written to ``<jsonfile name>.dot``.
    """
    args = parse_args()
    # argparse opened the file; close it as soon as it has been parsed.
    with args.jsonfile as jsonfile:
        traceroute = json.load(jsonfile)
    graph = json_to_graphviz(traceroute, args.no_rtt)
    # Dump the DOT source to stdout (call form matches the other prints).
    print(graph)
    graph.layout('dot')

    # Save to DOT
    if args.generate_dot:
        dotfile = '{name}.dot'.format(name=args.jsonfile.name)
        graph.write(dotfile)
        print('Generated DOT file: {f}'.format(f=dotfile))

    # Save to PNG
    pngfile = '{name}.png'.format(name=args.jsonfile.name)
    graph.draw(pngfile)
    print('Graph saved to {f}'.format(f=pngfile))


if __name__ == '__main__':
    main()
