// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2018 Facebook

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <net/if.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>

#include <bpf/bpf.h>
#include <bpf/libbpf.h>

#include "cgroup_helpers.h"

/* Path of the test cgroup that main() joins via cgroup_setup_and_join(). */
#define CGROUP_PATH		"/skb_cgroup_test"
/* Ancestor levels verified; looked up as map keys 0..NUM_CGROUP_LEVELS-1. */
#define NUM_CGROUP_LEVELS	4

/* IPv6 all-nodes link-local multicast address (RFC 4291, Section 2.7.1) */
#define LINKLOCAL_MULTICAST	"ff02::1"

/*
 * Build an IPv6 destination address for UDP port 1025.
 *
 * @ip:    textual IPv6 address, parsed with inet_pton()
 * @iface: interface name whose index becomes the scope id
 * @dst:   output; zeroed first, then filled in
 *
 * Returns 0 on success. Returns -1 (and logs why) if @ip is not a valid
 * IPv6 literal or @iface cannot be resolved to an interface index.
 */
static int mk_dst_addr(const char *ip, const char *iface,
		       struct sockaddr_in6 *dst)
{
	const in_port_t dport = htons(1025);

	memset(dst, 0, sizeof(*dst));
	dst->sin6_family = AF_INET6;
	dst->sin6_port = dport;

	if (inet_pton(AF_INET6, ip, &dst->sin6_addr) != 1) {
		log_err("Invalid IPv6: %s", ip);
		return -1;
	}

	/* if_nametoindex() reports failure by returning 0. */
	dst->sin6_scope_id = if_nametoindex(iface);
	if (dst->sin6_scope_id != 0)
		return 0;

	log_err("Failed to get index of iface: %s", iface);
	return -1;
}

/*
 * Send one small UDP datagram to LINKLOCAL_MULTICAST, scoped to @iface.
 *
 * Returns 0 once sendto() has accepted the datagram. Returns -1 if the
 * destination cannot be built, the socket cannot be created, or sendto()
 * fails; the failing step is logged. Any socket that was opened is closed
 * before returning.
 */
static int send_packet(const char *iface)
{
	struct sockaddr_in6 dst;
	char msg[] = "msg";
	ssize_t sent;
	int sock;

	if (mk_dst_addr(LINKLOCAL_MULTICAST, iface, &dst))
		return -1;

	sock = socket(AF_INET6, SOCK_DGRAM, 0);
	if (sock == -1) {
		log_err("Failed to create UDP socket");
		return -1;
	}

	sent = sendto(sock, &msg, sizeof(msg), 0,
		      (const struct sockaddr *)&dst, sizeof(dst));
	/* Log before close() so log_err() still sees sendto()'s errno. */
	if (sent == -1)
		log_err("Failed to send datagram");

	close(sock);
	return sent == -1 ? -1 : 0;
}

/*
 * Return a new fd for the first map referenced by BPF program @prog_id.
 *
 * Only one map id is requested from the kernel. If the program uses
 * several maps, only the first reported one is opened.
 *
 * Returns a map fd that the caller must close() on success. Returns a
 * negative value on failure, with the cause logged. The temporary program
 * fd is always released.
 */
int get_map_fd_by_prog_id(int prog_id)
{
	struct bpf_prog_info info = {};
	__u32 info_len = sizeof(info);
	__u32 map_ids[1] = { 0 };
	int map_fd = -1;
	int prog_fd;

	prog_fd = bpf_prog_get_fd_by_id(prog_id);
	if (prog_fd < 0) {
		log_err("Failed to get fd by prog id %d", prog_id);
		return -1;
	}

	/* Ask for at most one map id, written into map_ids[]. */
	info.nr_map_ids = 1;
	info.map_ids = (__u64)(unsigned long)map_ids;

	if (bpf_prog_get_info_by_fd(prog_fd, &info, &info_len)) {
		log_err("Failed to get info by prog fd %d", prog_fd);
		goto out;
	}

	/*
	 * NOTE(review): nr_map_ids is presumably rewritten by the kernel to the
	 * program's actual map count; 0 means the program uses no maps.
	 */
	if (!info.nr_map_ids) {
		log_err("No maps found for prog fd %d", prog_fd);
		goto out;
	}

	map_fd = bpf_map_get_fd_by_id(map_ids[0]);
	if (map_fd < 0)
		log_err("Failed to get fd by map id %u", map_ids[0]);
out:
	close(prog_fd);
	return map_fd;
}

/*
 * Verify the ancestor cgroup ids recorded by BPF program @prog_id.
 *
 * Each __u32 key 0..NUM_CGROUP_LEVELS-1 is looked up in the program's first
 * map (see get_map_fd_by_prog_id()). The __u64 value found there is compared
 * with the expected id for that level:
 *   0 - root cgroup, get_cgroup_id("/..")
 *   1 - get_cgroup_id(""), presumably the cgroup_helpers work dir (TODO confirm)
 *   2 - the test cgroup, CGROUP_PATH
 *   3 - 0, as there is no cgroup at that depth
 *
 * Returns 0 if every level matches. Returns -1 on the first lookup failure
 * or mismatch, which is logged.
 */
int check_ancestor_cgroup_ids(int prog_id)
{
	/* The table below is filled for exactly NUM_CGROUP_LEVELS == 4 levels. */
	__u64 expected_ids[NUM_CGROUP_LEVELS];
	__u64 actual_id;
	__u32 level;
	int err = -1;
	int map_fd;

	expected_ids[0] = get_cgroup_id("/..");	/* root cgroup */
	expected_ids[1] = get_cgroup_id("");
	expected_ids[2] = get_cgroup_id(CGROUP_PATH);
	expected_ids[3] = 0; /* non-existent cgroup */

	map_fd = get_map_fd_by_prog_id(prog_id);
	if (map_fd < 0)
		return -1;

	for (level = 0; level < NUM_CGROUP_LEVELS; ++level) {
		if (bpf_map_lookup_elem(map_fd, &level, &actual_id)) {
			log_err("Failed to lookup key %u", level);
			goto out;
		}
		/* Cast: __u64 is not unsigned long long on every ABI. */
		if (actual_id != expected_ids[level]) {
			log_err("%llx (actual) != %llx (expected), level: %u",
				(unsigned long long)actual_id,
				(unsigned long long)expected_ids[level], level);
			goto out;
		}
	}

	err = 0;
out:
	close(map_fd);
	return err;
}

/*
 * Parse @str as a non-negative decimal BPF program id.
 *
 * Returns 0 and stores the value in *@id on success. Returns -1 if @str is
 * empty, has trailing characters, is negative, or does not fit in an int.
 */
static int parse_prog_id(const char *str, int *id)
{
	char *end;
	long val;

	errno = 0;
	val = strtol(str, &end, 10);
	if (end == str || *end != '\0' || errno == ERANGE ||
	    val < 0 || val > INT_MAX)
		return -1;

	*id = (int)val;
	return 0;
}

/*
 * Usage: <prog> iface prog_id
 *
 * Sets up and joins the test cgroup CGROUP_PATH and sends one UDP datagram
 * out of @iface. It then checks the ancestor cgroup ids stored by the BPF
 * program with id @prog_id.
 *
 * Prints "[PASS]" and returns 0 on success. Prints "[FAIL]" and returns -1
 * otherwise. Argument errors print a message and exit with EXIT_FAILURE
 * before any cgroup is touched.
 */
int main(int argc, char **argv)
{
	int cgfd = -1;
	int prog_id;
	int err = 0;

	if (argc < 3) {
		fprintf(stderr, "Usage: %s iface prog_id\n", argv[0]);
		exit(EXIT_FAILURE);
	}

	/* Reject a malformed id up front instead of silently using 0. */
	if (parse_prog_id(argv[2], &prog_id)) {
		fprintf(stderr, "Invalid prog_id: %s\n", argv[2]);
		exit(EXIT_FAILURE);
	}

	/* Use libbpf 1.0 API mode */
	libbpf_set_strict_mode(LIBBPF_STRICT_ALL);

	cgfd = cgroup_setup_and_join(CGROUP_PATH);
	if (cgfd < 0)
		goto err;

	if (send_packet(argv[1]))
		goto err;

	if (check_ancestor_cgroup_ids(prog_id))
		goto err;

	goto out;
err:
	err = -1;
out:
	/* cgfd is negative if cgroup setup failed; nothing to close then. */
	if (cgfd >= 0)
		close(cgfd);
	cleanup_cgroup_environment();
	printf("[%s]\n", err ? "FAIL" : "PASS");
	return err;
}