// SPDX-License-Identifier: GPL-2.0-only
/****************************************************************************
 * Driver for Solarflare network controllers and boards
 * Copyright 2023, Advanced Micro Devices, Inc.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation, incorporated herein by reference.
 */

#include "tc_conntrack.h"
#include "tc.h"
#include "mae.h"

/* Forward declaration: efx_tc_ct_zone_free() passes this callback to
 * nf_flow_table_offload_del_cb() before the definition appears further
 * down the file.
 */
static int efx_tc_flow_block(enum tc_setup_type type, void *type_data,
			     void *cb_priv);

/* Hash table parameters for efx->tc->ct_zone_ht.  The key starts at offset 0
 * of struct efx_tc_ct_zone and covers every byte that precedes the @linkage
 * member; @linkage itself is the rhashtable list head.
 */
static const struct rhashtable_params efx_tc_ct_zone_ht_params = {
	.key_len	= offsetof(struct efx_tc_ct_zone, linkage),
	.key_offset	= 0,
	.head_offset	= offsetof(struct efx_tc_ct_zone, linkage),
};

/* Hash table parameters for efx->tc->ct_ht.  The key starts at offset 0 of
 * struct efx_tc_ct_entry and covers every byte that precedes the @linkage
 * member; lookups in this file pass a pointer to the flow cookie as the key.
 */
static const struct rhashtable_params efx_tc_ct_ht_params = {
	.key_len	= offsetof(struct efx_tc_ct_entry, linkage),
	.key_offset	= 0,
	.head_offset	= offsetof(struct efx_tc_ct_entry, linkage),
};

/* rhashtable_free_and_destroy() callback for the conntrack zone table.
 * @ptr: the struct efx_tc_ct_zone that is still hashed at teardown
 * @arg: unused (efx_tc_fini_conntrack() passes NULL)
 *
 * Logs the leftover zone, detaches our offload callback from its flowtable
 * and frees the zone.
 */
static void efx_tc_ct_zone_free(void *ptr, void *arg)
{
	struct efx_tc_ct_zone *zone = ptr;
	struct efx_nic *efx = zone->efx;

	netif_err(efx, drv, efx->net_dev,
		  "tc ct_zone %u still present at teardown, removing\n",
		  zone->zone);

	nf_flow_table_offload_del_cb(zone->nf_ft, efx_tc_flow_block, zone);
	/* Pair with the mutex_init() done when the zone was registered, as
	 * efx_tc_ct_unregister_zone() does on the normal teardown path.
	 */
	mutex_destroy(&zone->mutex);
	kfree(zone);
}

/* rhashtable_free_and_destroy() callback for the conntrack entry table.
 * @ptr: the struct efx_tc_ct_entry that is still hashed at teardown
 * @arg: the struct efx_nic the table belongs to
 */
static void efx_tc_ct_free(void *ptr, void *arg)
{
	struct efx_tc_ct_entry *ct_entry = ptr;
	struct efx_nic *nic = arg;

	netif_err(nic, drv, nic->net_dev,
		  "tc ct_entry %lx still present at teardown\n",
		  ct_entry->cookie);

	/* The table meta is already gone at this point, so the hardware CT
	 * entry cannot be removed; only the counter is given back before
	 * the software state is freed.
	 */
	efx_tc_flower_release_counter(nic, ct_entry->cnt);
	kfree(ct_entry);
}

/* Set up the two conntrack hash tables (zones and entries) in efx->tc.
 *
 * Return: 0 on success, or the negative error from rhashtable_init().  On
 * failure neither table is left initialised.
 */
int efx_tc_init_conntrack(struct efx_nic *efx)
{
	int rc;

	rc = rhashtable_init(&efx->tc->ct_zone_ht, &efx_tc_ct_zone_ht_params);
	if (rc < 0)
		return rc;

	rc = rhashtable_init(&efx->tc->ct_ht, &efx_tc_ct_ht_params);
	if (rc < 0) {
		/* Unwind the zone table that was set up above */
		rhashtable_destroy(&efx->tc->ct_zone_ht);
		return rc;
	}

	return 0;
}

/* Destroy both conntrack hash tables without freeing any entries.
 *
 * Only call this in init failure teardown.
 * Normal exit should fini instead as there may be entries in the table.
 */
void efx_tc_destroy_conntrack(struct efx_nic *efx)
{
	rhashtable_destroy(&efx->tc->ct_ht);
	rhashtable_destroy(&efx->tc->ct_zone_ht);
}

/* Normal-exit teardown of the conntrack hash tables.
 *
 * Any zone or entry still hashed is handed to efx_tc_ct_zone_free() or
 * efx_tc_ct_free() respectively, which log it and free it; the tables
 * themselves are then destroyed.  Zones are processed before entries.
 */
void efx_tc_fini_conntrack(struct efx_nic *efx)
{
	rhashtable_free_and_destroy(&efx->tc->ct_zone_ht, efx_tc_ct_zone_free, NULL);
	rhashtable_free_and_destroy(&efx->tc->ct_ht, efx_tc_ct_free, efx);
}

/* Convert the 32-bit big-endian TCP_FLAG_<flg> constant into a 16-bit
 * big-endian value by keeping its upper 16 bits, so it can be combined with
 * the __be16 flags of a flow_match_tcp.
 */
#define EFX_NF_TCP_FLAG(flg)	cpu_to_be16(be32_to_cpu(TCP_FLAG_##flg) >> 16)

/* Validate the match of a conntrack flow rule and copy it into @conn.
 * @efx: NIC, used here for debug logging only
 * @fr: flow rule supplied with the offload request
 * @conn: entry whose eth_proto, ip_proto, IP address and L4 port fields are
 *	filled in from the rule's keys
 *
 * The rule is accepted only if it exact-matches (all-ones mask) on the
 * address type, EtherType, IP protocol (TCP or UDP only), the source and
 * destination address of the selected IP version, and the source and
 * destination L4 port.  A match on TCP flags is optional and restricted, see
 * the checks at the end of the function.
 *
 * Return: 0 on success, -EOPNOTSUPP if the rule cannot be offloaded.  On
 * failure @conn may have been partially filled in.
 */
static int efx_tc_ct_parse_match(struct efx_nic *efx, struct flow_rule *fr,
				 struct efx_tc_ct_entry *conn)
{
	struct flow_dissector *dissector = fr->match.dissector;
	unsigned char ipv = 0;
	bool tcp = false;

	/* Work out the IP version from an exact match on the address type */
	if (flow_rule_match_key(fr, FLOW_DISSECTOR_KEY_CONTROL)) {
		struct flow_match_control fm;

		flow_rule_match_control(fr, &fm);
		if (IS_ALL_ONES(fm.mask->addr_type))
			switch (fm.key->addr_type) {
			case FLOW_DISSECTOR_KEY_IPV4_ADDRS:
				ipv = 4;
				break;
			case FLOW_DISSECTOR_KEY_IPV6_ADDRS:
				ipv = 6;
				break;
			default:
				break;
			}
	}

	if (!ipv) {
		netif_dbg(efx, drv, efx->net_dev,
			  "Conntrack missing ipv specification\n");
		return -EOPNOTSUPP;
	}

	/* Refuse any rule using a dissector key outside this set.  Note that
	 * META is permitted here but is not otherwise examined by this
	 * function.
	 */
	if (dissector->used_keys &
	    ~(BIT_ULL(FLOW_DISSECTOR_KEY_CONTROL) |
	      BIT_ULL(FLOW_DISSECTOR_KEY_BASIC) |
	      BIT_ULL(FLOW_DISSECTOR_KEY_IPV4_ADDRS) |
	      BIT_ULL(FLOW_DISSECTOR_KEY_IPV6_ADDRS) |
	      BIT_ULL(FLOW_DISSECTOR_KEY_PORTS) |
	      BIT_ULL(FLOW_DISSECTOR_KEY_TCP) |
	      BIT_ULL(FLOW_DISSECTOR_KEY_META))) {
		netif_dbg(efx, drv, efx->net_dev,
			  "Unsupported conntrack keys %#llx\n",
			  dissector->used_keys);
		return -EOPNOTSUPP;
	}

	/* EtherType and IP protocol: both mandatory and exact-match */
	if (flow_rule_match_key(fr, FLOW_DISSECTOR_KEY_BASIC)) {
		struct flow_match_basic fm;

		flow_rule_match_basic(fr, &fm);
		if (!IS_ALL_ONES(fm.mask->n_proto)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack eth_proto is not exact-match; mask %04x\n",
				   ntohs(fm.mask->n_proto));
			return -EOPNOTSUPP;
		}
		conn->eth_proto = fm.key->n_proto;
		/* EtherType must agree with the address type found above */
		if (conn->eth_proto != (ipv == 4 ? htons(ETH_P_IP)
						 : htons(ETH_P_IPV6))) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack eth_proto is not IPv%u, is %04x\n",
				   ipv, ntohs(conn->eth_proto));
			return -EOPNOTSUPP;
		}
		if (!IS_ALL_ONES(fm.mask->ip_proto)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ip_proto is not exact-match; mask %02x\n",
				   fm.mask->ip_proto);
			return -EOPNOTSUPP;
		}
		conn->ip_proto = fm.key->ip_proto;
		switch (conn->ip_proto) {
		case IPPROTO_TCP:
			/* remembered for the TCP flags check below */
			tcp = true;
			break;
		case IPPROTO_UDP:
			break;
		default:
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ip_proto not TCP or UDP, is %02x\n",
				   conn->ip_proto);
			return -EOPNOTSUPP;
		}
	} else {
		netif_dbg(efx, drv, efx->net_dev,
			  "Conntrack missing eth_proto, ip_proto\n");
		return -EOPNOTSUPP;
	}

	/* Source and destination address of the selected IP version: both
	 * mandatory and exact-match.
	 */
	if (ipv == 4 && flow_rule_match_key(fr, FLOW_DISSECTOR_KEY_IPV4_ADDRS)) {
		struct flow_match_ipv4_addrs fm;

		flow_rule_match_ipv4_addrs(fr, &fm);
		if (!IS_ALL_ONES(fm.mask->src)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ipv4.src is not exact-match; mask %08x\n",
				   ntohl(fm.mask->src));
			return -EOPNOTSUPP;
		}
		conn->src_ip = fm.key->src;
		if (!IS_ALL_ONES(fm.mask->dst)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ipv4.dst is not exact-match; mask %08x\n",
				   ntohl(fm.mask->dst));
			return -EOPNOTSUPP;
		}
		conn->dst_ip = fm.key->dst;
	} else if (ipv == 6 && flow_rule_match_key(fr, FLOW_DISSECTOR_KEY_IPV6_ADDRS)) {
		struct flow_match_ipv6_addrs fm;

		flow_rule_match_ipv6_addrs(fr, &fm);
		if (!efx_ipv6_addr_all_ones(&fm.mask->src)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ipv6.src is not exact-match; mask %pI6\n",
				   &fm.mask->src);
			return -EOPNOTSUPP;
		}
		conn->src_ip6 = fm.key->src;
		if (!efx_ipv6_addr_all_ones(&fm.mask->dst)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ipv6.dst is not exact-match; mask %pI6\n",
				   &fm.mask->dst);
			return -EOPNOTSUPP;
		}
		conn->dst_ip6 = fm.key->dst;
	} else {
		netif_dbg(efx, drv, efx->net_dev,
			  "Conntrack missing IPv%u addrs\n", ipv);
		return -EOPNOTSUPP;
	}

	/* L4 source and destination port: both mandatory and exact-match */
	if (flow_rule_match_key(fr, FLOW_DISSECTOR_KEY_PORTS)) {
		struct flow_match_ports fm;

		flow_rule_match_ports(fr, &fm);
		if (!IS_ALL_ONES(fm.mask->src)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ports.src is not exact-match; mask %04x\n",
				   ntohs(fm.mask->src));
			return -EOPNOTSUPP;
		}
		conn->l4_sport = fm.key->src;
		if (!IS_ALL_ONES(fm.mask->dst)) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack ports.dst is not exact-match; mask %04x\n",
				   ntohs(fm.mask->dst));
			return -EOPNOTSUPP;
		}
		conn->l4_dport = fm.key->dst;
	} else {
		netif_dbg(efx, drv, efx->net_dev, "Conntrack missing L4 ports\n");
		return -EOPNOTSUPP;
	}

	/* TCP flags: optional.  Nothing from this key is stored in @conn; the
	 * rule is merely checked for being compatible with what we do.
	 */
	if (flow_rule_match_key(fr, FLOW_DISSECTOR_KEY_TCP)) {
		__be16 tcp_interesting_flags;
		struct flow_match_tcp fm;

		if (!tcp) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Conntrack matching on TCP keys but ipproto is not tcp\n");
			return -EOPNOTSUPP;
		}
		flow_rule_match_tcp(fr, &fm);
		tcp_interesting_flags = EFX_NF_TCP_FLAG(SYN) |
					EFX_NF_TCP_FLAG(RST) |
					EFX_NF_TCP_FLAG(FIN);
		/* If any of the tcp_interesting_flags is set, we always
		 * inhibit CT lookup in LHS (so SW can update CT table).
		 * Hence a rule whose key requires one of them to be set is
		 * rejected.
		 */
		if (fm.key->flags & tcp_interesting_flags) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Unsupported conntrack tcp.flags %04x/%04x\n",
				   ntohs(fm.key->flags), ntohs(fm.mask->flags));
			return -EOPNOTSUPP;
		}
		/* Other TCP flags cannot be filtered at CT */
		if (fm.mask->flags & ~tcp_interesting_flags) {
			netif_dbg(efx, drv, efx->net_dev,
				  "Unsupported conntrack tcp.flags %04x/%04x\n",
				   ntohs(fm.key->flags), ntohs(fm.mask->flags));
			return -EOPNOTSUPP;
		}
	}

	return 0;
}

/* Handle FLOW_CLS_REPLACE: offload one conntrack entry belonging to @ct_zone.
 * @ct_zone: zone this flowtable callback was registered for
 * @tc: offload request carrying the cookie, match and actions
 *
 * A new entry is allocated and hashed by cookie, the rule's match and
 * actions are parsed into it, a CT counter is allocated, the entry is
 * inserted into hardware with efx_mae_insert_ct() and finally linked onto
 * the zone's list under the zone mutex.
 *
 * Return: 0 on success; -ENETDOWN if TC state is absent or not up, -ENOMEM,
 * -EEXIST if the cookie is already offloaded, -EOPNOTSUPP for an unsupported
 * match or action, or the error from the hash table, counter allocation or
 * hardware insert.  On failure nothing is left allocated or hashed.
 */
static int efx_tc_ct_replace(struct efx_tc_ct_zone *ct_zone,
			     struct flow_cls_offload *tc)
{
	struct flow_rule *fr = flow_cls_offload_flow_rule(tc);
	struct efx_tc_ct_entry *conn, *old;
	struct efx_nic *efx = ct_zone->efx;
	const struct flow_action_entry *fa;
	struct efx_tc_counter *cnt;
	int rc, i;

	if (WARN_ON(!efx->tc))
		return -ENETDOWN;
	if (WARN_ON(!efx->tc->up))
		return -ENETDOWN;

	conn = kzalloc(sizeof(*conn), GFP_USER);
	if (!conn)
		return -ENOMEM;
	conn->cookie = tc->cookie;
	/* Claim the cookie first; old is NULL if our entry went in, an
	 * existing entry if the cookie is already present, or an ERR_PTR.
	 */
	old = rhashtable_lookup_get_insert_fast(&efx->tc->ct_ht,
						&conn->linkage,
						efx_tc_ct_ht_params);
	if (IS_ERR(old)) {
		rc = PTR_ERR(old);
		goto release;
	} else if (old) {
		netif_dbg(efx, drv, efx->net_dev,
			  "Already offloaded conntrack (cookie %lx)\n", tc->cookie);
		rc = -EEXIST;
		goto release;
	}

	/* Parse match */
	conn->zone = ct_zone;
	rc = efx_tc_ct_parse_match(efx, fr, conn);
	if (rc)
		goto release;

	/* Parse actions */
	flow_action_for_each(i, fa, &fr->action) {
		switch (fa->id) {
		case FLOW_ACTION_CT_METADATA:
			conn->mark = fa->ct_metadata.mark;
			/* Only an all-zero label set is accepted */
			if (memchr_inv(fa->ct_metadata.labels, 0, sizeof(fa->ct_metadata.labels))) {
				netif_dbg(efx, drv, efx->net_dev,
					  "Setting CT label not supported\n");
				rc = -EOPNOTSUPP;
				goto release;
			}
			break;
		default:
			netif_dbg(efx, drv, efx->net_dev,
				  "Unhandled action %u for conntrack\n", fa->id);
			rc = -EOPNOTSUPP;
			goto release;
		}
	}

	/* fill in defaults for unmangled values */
	/* NOTE(review): nothing in this function sets conn->dnat, so with the
	 * zeroed allocation the source address and port are what get chosen.
	 */
	conn->nat_ip = conn->dnat ? conn->dst_ip : conn->src_ip;
	conn->l4_natport = conn->dnat ? conn->l4_dport : conn->l4_sport;

	cnt = efx_tc_flower_allocate_counter(efx, EFX_TC_COUNTER_TYPE_CT);
	if (IS_ERR(cnt)) {
		rc = PTR_ERR(cnt);
		goto release;
	}
	conn->cnt = cnt;

	rc = efx_mae_insert_ct(efx, conn);
	if (rc) {
		netif_dbg(efx, drv, efx->net_dev,
			  "Failed to insert conntrack, %d\n", rc);
		goto release;
	}
	mutex_lock(&ct_zone->mutex);
	list_add_tail(&conn->list, &ct_zone->cts);
	mutex_unlock(&ct_zone->mutex);
	return 0;
release:
	/* conn was zero-allocated, so cnt is only non-NULL once a counter
	 * has actually been attached above.
	 */
	if (conn->cnt)
		efx_tc_flower_release_counter(efx, conn->cnt);
	/* old is NULL only when our own entry made it into the table; for a
	 * duplicate or an insert error there is nothing of ours to unhash.
	 */
	if (!old)
		rhashtable_remove_fast(&efx->tc->ct_ht, &conn->linkage,
				       efx_tc_ct_ht_params);
	kfree(conn);
	return rc;
}

/* First stage of tearing down a conntrack entry: take it out of hardware and
 * out of the ct_ht hash table.  @conn itself and its counter stay allocated.
 *
 * Caller must follow with efx_tc_ct_remove_finish() after RCU grace period!
 */
static void efx_tc_ct_remove(struct efx_nic *efx, struct efx_tc_ct_entry *conn)
{
	/* Hardware first ... */
	int hw_rc = efx_mae_remove_ct(efx, conn);

	/* ... then software, whether or not the hardware removal worked */
	rhashtable_remove_fast(&efx->tc->ct_ht, &conn->linkage,
			       efx_tc_ct_ht_params);

	if (!hw_rc) {
		netif_dbg(efx, drv, efx->net_dev, "Removed conntrack %lx\n",
			  conn->cookie);
		return;
	}

	netif_err(efx, drv, efx->net_dev,
		  "Failed to remove conntrack %lx from hw, rc %d\n",
		  conn->cookie, hw_rc);
}

/* Second stage of tearing down a conntrack entry: release its counter and
 * free @conn.  Callers in this file run this after efx_tc_ct_remove() and a
 * synchronize_rcu().
 */
static void efx_tc_ct_remove_finish(struct efx_nic *efx, struct efx_tc_ct_entry *conn)
{
	/* Remove related CT counter.  This is delayed after the conn object we
	 * are working with has been successfully removed.  This protects the
	 * counter from being used-after-free inside efx_tc_ct_stats.
	 */
	efx_tc_flower_release_counter(efx, conn->cnt);
	kfree(conn);
}

/* Handle FLOW_CLS_DESTROY: remove the conntrack entry identified by the
 * cookie in @tc.
 * @ct_zone: zone this flowtable callback was registered for
 * @tc: offload request carrying the cookie
 *
 * Return: 0 on success, -ENOENT if no entry with that cookie is hashed.
 */
static int efx_tc_ct_destroy(struct efx_tc_ct_zone *ct_zone,
			     struct flow_cls_offload *tc)
{
	struct efx_nic *efx = ct_zone->efx;
	struct efx_tc_ct_entry *conn;

	conn = rhashtable_lookup_fast(&efx->tc->ct_ht, &tc->cookie,
				      efx_tc_ct_ht_params);
	if (!conn) {
		netif_warn(efx, drv, efx->net_dev,
			   "Conntrack %lx not found to remove\n", tc->cookie);
		return -ENOENT;
	}

	/* Unlink from the zone's list and unhash under the zone mutex */
	mutex_lock(&ct_zone->mutex);
	list_del(&conn->list);
	efx_tc_ct_remove(efx, conn);
	mutex_unlock(&ct_zone->mutex);
	/* Wait out RCU readers (efx_tc_ct_stats() looks entries up under
	 * rcu_read_lock()) before the counter and entry are freed.
	 */
	synchronize_rcu();
	efx_tc_ct_remove_finish(efx, conn);
	return 0;
}

/* Handle FLOW_CLS_STATS: report the last-used time of the conntrack entry
 * identified by the cookie in @tc.
 * @ct_zone: zone this flowtable callback was registered for
 * @tc: offload request; its stats member is updated on success
 *
 * The lookup and the counter access both happen inside one RCU read-side
 * critical section.
 *
 * Return: 0 on success, -ENOENT if no entry with that cookie is hashed.
 */
static int efx_tc_ct_stats(struct efx_tc_ct_zone *ct_zone,
			   struct flow_cls_offload *tc)
{
	struct efx_nic *efx = ct_zone->efx;
	struct efx_tc_ct_entry *entry;
	int rc = 0;

	rcu_read_lock();
	entry = rhashtable_lookup_fast(&efx->tc->ct_ht, &tc->cookie,
				       efx_tc_ct_ht_params);
	if (entry) {
		struct efx_tc_counter *counter = entry->cnt;

		spin_lock_bh(&counter->lock);
		/* Packet, byte and drop counts are passed as zero; only the
		 * last-use timestamp is reported.
		 */
		flow_stats_update(&tc->stats, 0, 0, 0, counter->touched,
				  FLOW_ACTION_HW_STATS_DELAYED);
		spin_unlock_bh(&counter->lock);
	} else {
		netif_warn(efx, drv, efx->net_dev,
			   "Conntrack %lx not found for stats\n", tc->cookie);
		rc = -ENOENT;
	}
	rcu_read_unlock();

	return rc;
}

/* Flowtable offload callback registered per conntrack zone.
 * @type: setup type; only TC_SETUP_CLSFLOWER is handled
 * @type_data: the struct flow_cls_offload describing the request
 * @cb_priv: the struct efx_tc_ct_zone given at registration
 *
 * Return: the result of the replace/destroy/stats handler, or -EOPNOTSUPP
 * for any other setup type or command.
 */
static int efx_tc_flow_block(enum tc_setup_type type, void *type_data,
			     void *cb_priv)
{
	struct efx_tc_ct_zone *zone = cb_priv;
	struct flow_cls_offload *cls = type_data;

	if (type == TC_SETUP_CLSFLOWER) {
		switch (cls->command) {
		case FLOW_CLS_REPLACE:
			return efx_tc_ct_replace(zone, cls);
		case FLOW_CLS_DESTROY:
			return efx_tc_ct_destroy(zone, cls);
		case FLOW_CLS_STATS:
			return efx_tc_ct_stats(zone, cls);
		default:
			break;
		}
	}

	return -EOPNOTSUPP;
}

/* Get the driver state for conntrack zone @zone, creating it if needed.
 * @efx: NIC the zone is tracked on
 * @zone: conntrack zone number
 * @ct_ft: netfilter flowtable backing the zone
 *
 * If the zone is already hashed a reference is taken on it and it is
 * returned.  Otherwise a new zone is hashed and efx_tc_flow_block() is
 * registered as an offload callback on @ct_ft, and the new zone is returned
 * with its reference count set to 1.
 *
 * Return: the zone on success; ERR_PTR(-ENOMEM) on allocation failure,
 * ERR_PTR(-EAGAIN) if an existing zone's reference count is zero, or the
 * ERR_PTR-encoded error from the hash table insert or from
 * nf_flow_table_offload_add_cb().
 */
struct efx_tc_ct_zone *efx_tc_ct_register_zone(struct efx_nic *efx, u16 zone,
					       struct nf_flowtable *ct_ft)
{
	struct efx_tc_ct_zone *ct_zone, *old;
	int rc;

	ct_zone = kzalloc(sizeof(*ct_zone), GFP_USER);
	if (!ct_zone)
		return ERR_PTR(-ENOMEM);
	ct_zone->zone = zone;
	old = rhashtable_lookup_get_insert_fast(&efx->tc->ct_zone_ht,
						&ct_zone->linkage,
						efx_tc_ct_zone_ht_params);
	if (old) {
		/* don't need our new entry */
		kfree(ct_zone);
		if (IS_ERR(old)) /* oh dear, it's actually an error */
			return ERR_CAST(old);
		if (!refcount_inc_not_zero(&old->ref))
			return ERR_PTR(-EAGAIN);
		/* existing entry found */
		WARN_ON_ONCE(old->nf_ft != ct_ft);
		netif_dbg(efx, drv, efx->net_dev,
			  "Found existing ct_zone for %u\n", zone);
		return old;
	}
	/* Our entry is now hashed; finish setting it up before the callback
	 * registered below can be handed it as cb_priv.
	 */
	ct_zone->nf_ft = ct_ft;
	ct_zone->efx = efx;
	INIT_LIST_HEAD(&ct_zone->cts);
	mutex_init(&ct_zone->mutex);
	rc = nf_flow_table_offload_add_cb(ct_ft, efx_tc_flow_block, ct_zone);
	netif_dbg(efx, drv, efx->net_dev, "Adding new ct_zone for %u, rc %d\n",
		  zone, rc);
	if (rc < 0)
		goto fail;
	refcount_set(&ct_zone->ref, 1);
	return ct_zone;
fail:
	rhashtable_remove_fast(&efx->tc->ct_zone_ht, &ct_zone->linkage,
			       efx_tc_ct_zone_ht_params);
	/* Pair with mutex_init() above, matching what
	 * efx_tc_ct_unregister_zone() does before it frees a zone.
	 */
	mutex_destroy(&ct_zone->mutex);
	kfree(ct_zone);
	return ERR_PTR(rc);
}

/* Drop one reference on @ct_zone; on the last reference, tear it down.
 * @efx: NIC the zone is tracked on
 * @ct_zone: zone previously returned by efx_tc_ct_register_zone()
 *
 * Teardown unregisters the flowtable offload callback, unhashes the zone,
 * removes every conntrack entry still on the zone's list (hardware and hash
 * table first, then counter and memory after an RCU grace period) and frees
 * the zone.
 */
void efx_tc_ct_unregister_zone(struct efx_nic *efx,
			       struct efx_tc_ct_zone *ct_zone)
{
	struct efx_tc_ct_entry *conn, *next;

	if (!refcount_dec_and_test(&ct_zone->ref))
		return; /* still in use */
	nf_flow_table_offload_del_cb(ct_zone->nf_ft, efx_tc_flow_block, ct_zone);
	rhashtable_remove_fast(&efx->tc->ct_zone_ht, &ct_zone->linkage,
			       efx_tc_ct_zone_ht_params);
	mutex_lock(&ct_zone->mutex);
	/* First pass: take every entry out of hardware and out of ct_ht */
	list_for_each_entry(conn, &ct_zone->cts, list)
		efx_tc_ct_remove(efx, conn);
	/* One grace period covers all the entries removed above */
	synchronize_rcu();
	/* need to use _safe because efx_tc_ct_remove_finish() frees conn */
	list_for_each_entry_safe(conn, next, &ct_zone->cts, list)
		efx_tc_ct_remove_finish(efx, conn);
	mutex_unlock(&ct_zone->mutex);
	mutex_destroy(&ct_zone->mutex);
	netif_dbg(efx, drv, efx->net_dev, "Removed ct_zone for %u\n",
		  ct_zone->zone);
	kfree(ct_zone);
}