#include <cstring>
#include <functional>
#include <iomanip>

#include "gns3/GnsPort.hpp"
#include "gns3/GnsServer.hpp"

#include <iostream>
#include <memory>
#include <pcap.h>
#include <netinet/ether.h>
#include <netinet/udp.h>

#include "utils/address.hpp"


/**
 * Callback of a detected Wake-On-LAN packet
 */
void packet_wol_handler(
    unsigned char *user_data,
    const pcap_pkthdr* header,
    const std::uint8_t* packet
) {
    const auto server = reinterpret_cast<gns::GnsServer*>(user_data);

    // get the special sll header (added by pcap when using the "any" device)
    std::uint16_t sll_header = *packet;
    // get the ethernet header of the packet
    auto* packet_ethernet_header = reinterpret_cast<const ether_header*>(packet + sizeof(sll_header));
    // get the content of the packet
    const std::uint8_t* packet_content_start = packet + sizeof(sll_header) + sizeof(ether_header);
    const std::vector packet_content(packet_content_start, packet_content_start + header->len);

    std::cout << "Captured a WoL packet." << std::endl;
    std::cout << "Packet length: " + std::to_string(header->len) + " bytes" << std::endl;

    // get the source address
    const std::string mac_address_source = utils::address::mac_bytes_to_string(packet_ethernet_header->ether_shost);
    std::cout << "Source: " << mac_address_source << std::endl;

    // TODO(Faraphel): check the magic header ? content[6:12]

    // get the destination address
    const std::string mac_address_target = utils::address::mac_bytes_to_string(packet_content.data() + 6);
    std::cout << "Destination: " << mac_address_target << std::endl;

    // TODO(Faraphel): check the 16 repetitions of the source address ?

    // find the machine with the mac address
    for (const gns::GnsNode& node : server->getNodes())
        for (const gns::GnsPort& port : node.getPorts())
            if (port.mac_address == mac_address_target) {
                std::cout << "Matching node: " + node.getUuid() << std::endl;
                node.start();
                std::cout << "Node started." << std::endl;
                return;
            }

    std::cerr << "Found no matching node." << std::endl;
}

int main() {
    // get the GNS3 server
    auto server = gns::GnsServer("localhost", 80);

    // capture any packet on the selected interface
    char error_buffer[PCAP_ERRBUF_SIZE];
    const auto handle = std::unique_ptr<pcap_t, decltype(&pcap_close)>(
        pcap_open_live("any", BUFSIZ, 1, 1000, error_buffer),  // capture on all interface
        pcap_close
    );

    if (handle == nullptr)
        throw std::runtime_error("pcap_open_live() failed: " + std::string(error_buffer));

    // compile the packet filter
    bpf_program filter {};
    const std::string filter_expression = "ether proto 0x0842";  // only match WoL packets
    if (pcap_compile(handle.get(), &filter, filter_expression.c_str(), 0, PCAP_NETMASK_UNKNOWN) == -1)
        throw std::runtime_error("pcap_compile() failed: " + std::string(error_buffer));

    // apply the filter
    if (pcap_setfilter(handle.get(), &filter) == -1)
        throw std::runtime_error("pcap_setfilter() failed: " + std::string(error_buffer));

    // Start packet capture loop
    if (pcap_loop(
        handle.get(),
        0,
        packet_wol_handler,
        reinterpret_cast<uint8_t*>(&server)
    ) == -1)
        throw std::runtime_error("pcap_loop() failed: " + std::string(error_buffer));

    // TODO(Faraphel): more forgiving exception handling
    // TODO(Faraphel): if possible, check if the two ports are in the same network.

    return 0;
}