/*
 * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 */

#include <vmlinux.h>

#include <bpf/bpf_core_read.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

#include "syscall_trace.h"

/*
 * Per-task state kept in the task-local storage map `tls_map`. It is filled in
 * by `nsys_trace_sys_enter` and read back by `nsys_trace_sys_exit`.
 */
struct tls_data {
    u64 timestamp_start; /* bpf_ktime_get_ns() value taken at sys_enter */
    u64 syscall_nr;      /* syscall number received by the sys_enter program */
    u32 tgid;            /* TGID reported by bpf_get_ns_current_pid_tgid() */
    u32 pid;             /* PID reported by bpf_get_ns_current_pid_tgid() */
    u16 seq_nr;          /* incremented by one on every traced sys_enter */
};

/*
 * Device and inode numbers that uniquely identify the PID namespace of the
 * processes to trace; this is what makes filtering inside containers possible.
 * https://lore.kernel.org/bpf/20200304204157.58695-1-cneirabustos@gmail.com/
 * Both variables must be set before the BPF programs are loaded.
 */
const volatile __u64 allowed_pidns_dev;
const volatile __u64 allowed_pidns_ino;

/*
 * The following map is used to store the TGIDs (as seen from the default PID
 * namespace) of the processes that need to be traced. The field `max_entries`
 * must be set before loading the BPF programs. After the BPF programs have been
 * loaded, this map must be updated with the TGIDs (as seen from the default PID
 * namespace) of the processes that need to be traced.
 */
SEC(".maps")
struct {
    __uint(type, BPF_MAP_TYPE_HASH);
    __type(key, u32);
    __type(value, u32);
} allowed_tgids;

/*
 * The following map is used to store information for each desired `task_struct`
 * object in local storage. The BPF iterator `nsys_delete_tls` must be triggered
 * before (re)attaching the programs `nsys_trace_sys_enter` and
 * `nsys_trace_sys_exit`. This will ensure that there is no stale data in the
 * local storage from any previous attachments of the BPF programs
 * `nsys_trace_sys_enter` and `nsys_trace_sys_exit`.
 */
SEC(".maps")
struct {
    __uint(type, BPF_MAP_TYPE_TASK_STORAGE);
    __uint(map_flags, BPF_F_NO_PREALLOC);
    __type(key, int);
    __type(value, struct tls_data);
} tls_map;

/*
 * The following map is a ring buffer that is used to transfer trace data
 * records from the kernel space to the user space. The field `max_entries` must
 * be set before loading the BPF programs.
 */
SEC(".maps")
struct {
    __uint(type, BPF_MAP_TYPE_RINGBUF);
} ring_buffer;

/*
 * Decide whether the current task is one that must be traced.
 *
 * The PID/TGID of the current task is resolved against the PID namespace
 * described by `allowed_pidns_dev` / `allowed_pidns_ino` and written to
 * `nsdata`. The task is allowed only if that lookup succeeds and the resulting
 * TGID is present in the map `allowed_tgids`.
 *
 * Returns true when the current task is allowed, false otherwise. On a false
 * return caused by the namespace lookup failing, `nsdata` must not be relied
 * upon.
 */
static bool allow_current_task(struct bpf_pidns_info *nsdata)
{
    long ns_err = bpf_get_ns_current_pid_tgid(allowed_pidns_dev,
                                              allowed_pidns_ino, nsdata,
                                              sizeof(*nsdata));

    if (ns_err)
        return false;

    if (!bpf_map_lookup_elem(&allowed_tgids, &nsdata->tgid))
        return false;

    return true;
}

static struct pid_namespace *get_pidns(struct task_struct *task)
{
    struct pid *pid = BPF_CORE_READ(task, thread_pid);
    unsigned int level = BPF_CORE_READ(pid, level);

    return BPF_CORE_READ(pid, numbers[level].ns);
}

static int get_tgid_vnr(struct task_struct *task)
{
    struct pid *pid = BPF_CORE_READ(task, signal, pids[PIDTYPE_TGID]);
    unsigned int level = BPF_CORE_READ(pid, level);

    return BPF_CORE_READ(pid, numbers[level].nr);
}

/*
 * Tracks forks performed by the processes that need to be traced, adding the
 * TGID of each newly forked process to the BPF map `allowed_tgids`.
 *
 * This program must stay attached at all times, even while tracing is not
 * being done, so that the BPF map `allowed_tgids` is kept up to date.
 *
 * Always returns 0; a failed map update is only reported via bpf_printk().
 */
SEC("tp_btf/sched_process_fork")
int BPF_PROG(nsys_track_sched_process_fork, struct task_struct *parent,
                                            struct task_struct *child)
{
    struct bpf_pidns_info parent_ids;
    u32 child_tgid;
    long rc;

    /* The current task is `parent`; ignore forks from untraced processes. */
    if (!allow_current_task(&parent_ids))
        return 0;

    /* Only follow children whose PID namespace matches the parent's. */
    if (get_pidns(child) != get_pidns(parent))
        return 0;

    /*
     * The parent's TGID was just found in `allowed_tgids`, so a child that
     * carries the same TGID needs no new entry.
     */
    child_tgid = get_tgid_vnr(child);
    if (child_tgid == parent_ids.tgid)
        return 0;

    rc = bpf_map_update_elem(&allowed_tgids, &child_tgid, &child_tgid,
                             BPF_NOEXIST);
    if (rc)
        bpf_printk("%s: bpf_map_update_elem() error: %ld", __func__, -rc);

    return 0;
}

/*
 * Tracks the exits of the processes that need to be traced, removing the TGID
 * of an exiting process from the BPF map `allowed_tgids`.
 *
 * This program must stay attached at all times, even while tracing is not
 * being done, so that the BPF map `allowed_tgids` is kept up to date.
 *
 * Always returns 0; a failed map delete is only reported via bpf_printk().
 */
SEC("tp_btf/sched_process_exit")
int BPF_PROG(nsys_track_sched_process_exit, struct task_struct *task)
{
    struct bpf_pidns_info ids;
    long rc;

    /* The current task is `task`; ignore exits of untraced processes. */
    if (!allow_current_task(&ids))
        return 0;

    /* Act only when the exiting task's PID equals its TGID. */
    if (ids.pid != ids.tgid)
        return 0;

    rc = bpf_map_delete_elem(&allowed_tgids, &ids.tgid);
    if (rc)
        bpf_printk("%s: bpf_map_delete_elem() error: %ld", __func__, -rc);

    return 0;
}

/*
 * Traces the `sys_enter` tracepoint.
 *
 * For an allowed task, the start timestamp and syscall number are recorded in
 * the task's entry of `tls_map`, and its sequence number is advanced. The
 * entry is created, and its TGID/PID filled in, the first time an allowed task
 * is seen. `regs` is not used.
 *
 * Always returns 0.
 */
SEC("tp_btf/sys_enter")
int BPF_PROG(nsys_trace_sys_enter, struct pt_regs *regs, long syscall_nr)
{
    struct task_struct *cur = bpf_get_current_task_btf();
    struct tls_data *slot = bpf_task_storage_get(&tls_map, cur, NULL, 0);

    if (!slot) {
        struct bpf_pidns_info ids;

        /* No storage yet: only allowed tasks get an entry created. */
        if (!allow_current_task(&ids))
            return 0;

        slot = bpf_task_storage_get(&tls_map, cur, NULL,
                                    BPF_LOCAL_STORAGE_GET_F_CREATE);
        if (!slot) {
            bpf_printk("%s: bpf_task_storage_get() error", __func__);
            return 0;
        }

        /* Identity is written once, when the entry is created. */
        slot->tgid = ids.tgid;
        slot->pid = ids.pid;
    }

    slot->timestamp_start = bpf_ktime_get_ns();
    slot->syscall_nr = syscall_nr;
    slot->seq_nr++;

    return 0;
}

/*
 * Traces the `sys_exit` tracepoint.
 *
 * For a task that has an entry in `tls_map`, one trace data record is built
 * from that entry plus the current time and submitted into the BPF map
 * `ring_buffer`. `regs` and `ret` are not used.
 *
 * Because records are submitted from this program, it must be detached before
 * detaching the BPF program `nsys_trace_sys_enter`; this ensures that there is
 * no incorrect data in the BPF map `ring_buffer`.
 *
 * Always returns 0; a failed reservation is only reported via bpf_printk().
 */
SEC("tp_btf/sys_exit")
int BPF_PROG(nsys_trace_sys_exit, struct pt_regs *regs, long ret)
{
    struct task_struct *cur = bpf_get_current_task_btf();
    struct tls_data *slot = bpf_task_storage_get(&tls_map, cur, NULL, 0);
    struct syscall_trace_data *rec;

    /* Tasks without storage were never recorded by the sys_enter program. */
    if (!slot)
        return 0;

    rec = bpf_ringbuf_reserve(&ring_buffer, sizeof(*rec), 0);
    if (!rec) {
        bpf_printk("%s: bpf_ringbuf_reserve() error", __func__);
        return 0;
    }

    rec->timestamp_start = slot->timestamp_start;
    rec->timestamp_end = bpf_ktime_get_ns();
    rec->tgid = slot->tgid;
    rec->pid = slot->pid;
    rec->syscall_nr = slot->syscall_nr;
    rec->seq_nr = slot->seq_nr;

    bpf_ringbuf_submit(rec, 0);

    return 0;
}

/*
 * Task iterator that deletes the local storage associated with the BPF map
 * `tls_map` from every `task_struct` object that currently has one.
 *
 * Always returns 0; a failed delete is only reported via bpf_printk().
 */
SEC("iter/task")
int nsys_delete_tls(struct bpf_iter__task *ctx)
{
    struct task_struct *t = ctx->task;
    long rc;

    /* Nothing to do when the iterator supplies no task. */
    if (!t)
        return 0;

    /* Skip tasks that have no storage in `tls_map`. */
    if (!bpf_task_storage_get(&tls_map, t, NULL, 0))
        return 0;

    rc = bpf_task_storage_delete(&tls_map, t);
    if (rc)
        bpf_printk("%s: bpf_task_storage_delete() error: %ld", __func__, -rc);

    return 0;
}

/* License string for the BPF programs in this file, placed in the "license" section. */
char LICENSE[] SEC("license") = "GPL v2";
