#include <linux/types.h>
#include <linux/socket.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/inet.h>
#include <linux/fdtable.h>
#include <linux/rhashtable.h>
#include <net/sock.h>
#include <net/genetlink.h>
#include <net/netns/generic.h>
#include <kunit/visibility.h>
#include <uapi/linux/handshake.h>
#include "handshake.h"
#include <trace/events/handshake.h>
static struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp;
static const struct rhashtable_params handshake_rhash_params = {
.key_len = sizeof_field(struct handshake_req, hr_sk),
.key_offset = offsetof(struct handshake_req, hr_sk),
.head_offset = offsetof(struct handshake_req, hr_rhash),
.automatic_shrinking = true,
};
int handshake_req_hash_init(void)
{
return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params);
}
void handshake_req_hash_destroy(void)
{
rhashtable_destroy(&handshake_rhashtbl);
}
struct handshake_req *handshake_req_hash_lookup(struct sock *sk)
{
return rhashtable_lookup_fast(&handshake_rhashtbl, &sk,
handshake_rhash_params);
}
EXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup);
static bool handshake_req_hash_add(struct handshake_req *req)
{
int ret;
ret = rhashtable_lookup_insert_fast(&handshake_rhashtbl,
&req->hr_rhash,
handshake_rhash_params);
return ret == 0;
}
static void handshake_req_destroy(struct handshake_req *req)
{
if (req->hr_proto->hp_destroy)
req->hr_proto->hp_destroy(req);
rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash,
handshake_rhash_params);
kfree(req);
}
static void handshake_sk_destruct(struct sock *sk)
{
void (*sk_destruct)(struct sock *sk);
struct handshake_req *req;
req = handshake_req_hash_lookup(sk);
if (!req)
return;
trace_handshake_destruct(sock_net(sk), req, sk);
sk_destruct = req->hr_odestruct;
handshake_req_destroy(req);
if (sk_destruct)
sk_destruct(sk);
}
struct handshake_req *handshake_req_alloc(const struct handshake_proto *proto,
gfp_t flags)
{
struct handshake_req *req;
if (!proto)
return NULL;
if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE)
return NULL;
if (proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX)
return NULL;
if (!proto->hp_accept || !proto->hp_done)
return NULL;
req = kzalloc(struct_size(req, hr_priv, proto->hp_privsize), flags);
if (!req)
return NULL;
INIT_LIST_HEAD(&req->hr_list);
req->hr_proto = proto;
return req;
}
EXPORT_SYMBOL(handshake_req_alloc);
void *handshake_req_private(struct handshake_req *req)
{
return (void *)&req->hr_priv;
}
EXPORT_SYMBOL(handshake_req_private);
static bool __add_pending_locked(struct handshake_net *hn,
struct handshake_req *req)
{
if (WARN_ON_ONCE(!list_empty(&req->hr_list)))
return false;
hn->hn_pending++;
list_add_tail(&req->hr_list, &hn->hn_requests);
return true;
}
static void __remove_pending_locked(struct handshake_net *hn,
struct handshake_req *req)
{
hn->hn_pending--;
list_del_init(&req->hr_list);
}
static bool remove_pending(struct handshake_net *hn, struct handshake_req *req)
{
bool ret = false;
spin_lock(&hn->hn_lock);
if (!list_empty(&req->hr_list)) {
__remove_pending_locked(hn, req);
ret = true;
}
spin_unlock(&hn->hn_lock);
return ret;
}
struct handshake_req *handshake_req_next(struct handshake_net *hn, int class)
{
struct handshake_req *req, *pos;
req = NULL;
spin_lock(&hn->hn_lock);
list_for_each_entry(pos, &hn->hn_requests, hr_list) {
if (pos->hr_proto->hp_handler_class != class)
continue;
__remove_pending_locked(hn, pos);
req = pos;
break;
}
spin_unlock(&hn->hn_lock);
return req;
}
EXPORT_SYMBOL_IF_KUNIT(handshake_req_next);
int handshake_req_submit(struct socket *sock, struct handshake_req *req,
gfp_t flags)
{
struct handshake_net *hn;
struct net *net;
int ret;
if (!sock || !req || !sock->file) {
kfree(req);
return -EINVAL;
}
req->hr_sk = sock->sk;
if (!req->hr_sk) {
kfree(req);
return -EINVAL;
}
req->hr_odestruct = req->hr_sk->sk_destruct;
req->hr_sk->sk_destruct = handshake_sk_destruct;
ret = -EOPNOTSUPP;
net = sock_net(req->hr_sk);
hn = handshake_pernet(net);
if (!hn)
goto out_err;
ret = -EAGAIN;
if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max)
goto out_err;
spin_lock(&hn->hn_lock);
ret = -EOPNOTSUPP;
if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags))
goto out_unlock;
ret = -EBUSY;
if (!handshake_req_hash_add(req))
goto out_unlock;
if (!__add_pending_locked(hn, req))
goto out_unlock;
spin_unlock(&hn->hn_lock);
ret = handshake_genl_notify(net, req->hr_proto, flags);
if (ret) {
trace_handshake_notify_err(net, req, req->hr_sk, ret);
if (remove_pending(hn, req))
goto out_err;
}
sock_hold(req->hr_sk);
trace_handshake_submit(net, req, req->hr_sk);
return 0;
out_unlock:
spin_unlock(&hn->hn_lock);
out_err:
trace_handshake_submit_err(net, req, req->hr_sk, ret);
handshake_req_destroy(req);
return ret;
}
EXPORT_SYMBOL(handshake_req_submit);
void handshake_complete(struct handshake_req *req, unsigned int status,
struct genl_info *info)
{
struct sock *sk = req->hr_sk;
struct net *net = sock_net(sk);
if (!test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
trace_handshake_complete(net, req, sk, status);
req->hr_proto->hp_done(req, status, info);
sock_put(sk);
}
}
EXPORT_SYMBOL_IF_KUNIT(handshake_complete);
bool handshake_req_cancel(struct sock *sk)
{
struct handshake_req *req;
struct handshake_net *hn;
struct net *net;
net = sock_net(sk);
req = handshake_req_hash_lookup(sk);
if (!req) {
trace_handshake_cancel_none(net, req, sk);
return false;
}
hn = handshake_pernet(net);
if (hn && remove_pending(hn, req)) {
goto out_true;
}
if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
trace_handshake_cancel_busy(net, req, sk);
return false;
}
out_true:
trace_handshake_cancel(net, req, sk);
sock_put(sk);
return true;
}
EXPORT_SYMBOL