// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. */

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"
#include "test_user_ringbuf.h"

char _license[] SEC("license") = "GPL";

struct {
	__uint(type, BPF_MAP_TYPE_USER_RINGBUF);
} user_ringbuf SEC(".maps");

struct {
	__uint(type, BPF_MAP_TYPE_RINGBUF);
} kernel_ringbuf SEC(".maps");

/* inputs */
int pid, err, val;

int read = 0;

/* Counter used for end-to-end protocol test */
__u64 kern_mutated = 0;
__u64 user_mutated = 0;
__u64 expected_user_mutated = 0;

static int
is_test_process(void)
{
	int cur_pid = bpf_get_current_pid_tgid() >> 32;

	return cur_pid == pid;
}

static long
record_sample(struct bpf_dynptr *dynptr, void *context)
{
	const struct sample *sample = NULL;
	struct sample stack_sample;
	int status;
	static int num_calls;

	if (num_calls++ % 2 == 0) {
		status = bpf_dynptr_read(&stack_sample, sizeof(stack_sample), dynptr, 0, 0);
		if (status) {
			bpf_printk("bpf_dynptr_read() failed: %d\n", status);
			err = 1;
			return 1;
		}
	} else {
		sample = bpf_dynptr_data(dynptr, 0, sizeof(*sample));
		if (!sample) {
			bpf_printk("Unexpectedly failed to get sample\n");
			err = 2;
			return 1;
		}
		stack_sample = *sample;
	}

	__sync_fetch_and_add(&read, 1);
	return 0;
}

static void
handle_sample_msg(const struct test_msg *msg)
{
	switch (msg->msg_op) {
	case TEST_MSG_OP_INC64:
		kern_mutated += msg->operand_64;
		break;
	case TEST_MSG_OP_INC32:
		kern_mutated += msg->operand_32;
		break;
	case TEST_MSG_OP_MUL64:
		kern_mutated *= msg->operand_64;
		break;
	case TEST_MSG_OP_MUL32:
		kern_mutated *= msg->operand_32;
		break;
	default:
		bpf_printk("Unrecognized op %d\n", msg->msg_op);
		err = 2;
	}
}

static long
read_protocol_msg(struct bpf_dynptr *dynptr, void *context)
{
	const struct test_msg *msg = NULL;

	msg = bpf_dynptr_data(dynptr, 0, sizeof(*msg));
	if (!msg) {
		err = 1;
		bpf_printk("Unexpectedly failed to get msg\n");
		return 0;
	}

	handle_sample_msg(msg);

	return 0;
}

static int publish_next_kern_msg(__u32 index, void *context)
{
	struct test_msg *msg = NULL;
	int operand_64 = TEST_OP_64;
	int operand_32 = TEST_OP_32;

	msg = bpf_ringbuf_reserve(&kernel_ringbuf, sizeof(*msg), 0);
	if (!msg) {
		err = 4;
		return 1;
	}

	switch (index % TEST_MSG_OP_NUM_OPS) {
	case TEST_MSG_OP_INC64:
		msg->operand_64 = operand_64;
		msg->msg_op = TEST_MSG_OP_INC64;
		expected_user_mutated += operand_64;
		break;
	case TEST_MSG_OP_INC32:
		msg->operand_32 = operand_32;
		msg->msg_op = TEST_MSG_OP_INC32;
		expected_user_mutated += operand_32;
		break;
	case TEST_MSG_OP_MUL64:
		msg->operand_64 = operand_64;
		msg->msg_op = TEST_MSG_OP_MUL64;
		expected_user_mutated *= operand_64;
		break;
	case TEST_MSG_OP_MUL32:
		msg->operand_32 = operand_32;
		msg->msg_op = TEST_MSG_OP_MUL32;
		expected_user_mutated *= operand_32;
		break;
	default:
		bpf_ringbuf_discard(msg, 0);
		err = 5;
		return 1;
	}

	bpf_ringbuf_submit(msg, 0);

	return 0;
}

static void
publish_kern_messages(void)
{
	if (expected_user_mutated != user_mutated) {
		bpf_printk("%lu != %lu\n", expected_user_mutated, user_mutated);
		err = 3;
		return;
	}

	bpf_loop(8, publish_next_kern_msg, NULL, 0);
}

SEC("fentry/" SYS_PREFIX "sys_prctl")
int test_user_ringbuf_protocol(void *ctx)
{
	long status = 0;

	if (!is_test_process())
		return 0;

	status = bpf_user_ringbuf_drain(&user_ringbuf, read_protocol_msg, NULL, 0);
	if (status < 0) {
		bpf_printk("Drain returned: %ld\n", status);
		err = 1;
		return 0;
	}

	publish_kern_messages();

	return 0;
}

SEC("fentry/" SYS_PREFIX "sys_getpgid")
int test_user_ringbuf(void *ctx)
{
	if (!is_test_process())
		return 0;

	err = bpf_user_ringbuf_drain(&user_ringbuf, record_sample, NULL, 0);

	return 0;
}

static long
do_nothing_cb(struct bpf_dynptr *dynptr, void *context)
{
	__sync_fetch_and_add(&read, 1);
	return 0;
}

SEC("fentry/" SYS_PREFIX "sys_prlimit64")
int test_user_ringbuf_epoll(void *ctx)
{
	long num_samples;

	if (!is_test_process())
		return 0;

	num_samples = bpf_user_ringbuf_drain(&user_ringbuf, do_nothing_cb, NULL, 0);
	if (num_samples <= 0)
		err = 1;

	return 0;
}