// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
 *
 * The AUTH_TLS credential is used only to probe a remote peer
 * for RPC-over-TLS support.
 */

#include <linux/types.h>
#include <linux/module.h>
#include <linux/sunrpc/clnt.h>

/*
 * Opaque verifier body that tls_validate() expects in the reply to a
 * probe. The token is a true constant array (not a reassignable
 * pointer) and its length is derived from the literal, excluding the
 * terminating NUL, so the two can never drift apart.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;

/*
 * Forward declarations: the functions below take the address of these
 * objects before their initializers appear later in the file.
 */
static struct rpc_cred tls_cred;
static struct rpc_auth tls_auth;

/*
 * Argument encoder for the probe procedure. It writes nothing into
 * @xdr and ignores @rqstp and @obj.
 */
static void tls_encode_probe(struct rpc_rqst *rqstp,
			     struct xdr_stream *xdr,
			     const void *obj)
{
	/* Intentionally empty. */
}

/*
 * Result decoder for the probe procedure. It reads nothing from @xdr,
 * ignores @rqstp and @obj, and always reports success.
 *
 * Returns 0.
 */
static int tls_decode_probe(struct rpc_rqst *rqstp,
			    struct xdr_stream *xdr,
			    void *obj)
{
	const int ok = 0;

	return ok;
}

/*
 * Procedure descriptor used by tls_probe(): only the argument encoder
 * and result decoder are set; every other field is left zeroed.
 */
static const struct rpc_procinfo rpcproc_tls_probe = {
	.p_decode	= tls_decode_probe,
	.p_encode	= tls_encode_probe,
};

/*
 * rpc_call_prepare callback for the probe task: clear
 * RPC_TASK_NO_RETRANS_TIMEOUT from the task's flags, then start the
 * call. @data is unused.
 */
static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
{
	task->tk_flags = task->tk_flags & ~RPC_TASK_NO_RETRANS_TIMEOUT;

	rpc_call_start(task);
}

/*
 * rpc_call_done callback for the probe task. It does nothing; both
 * @task and @data are ignored.
 */
static void rpc_tls_probe_call_done(struct rpc_task *task,
				    void *data)
{
	/* Intentionally empty. */
}

/* Task callbacks installed by tls_probe() on the probe task. */
static const struct rpc_call_ops rpc_tls_probe_ops = {
	.rpc_call_done		= rpc_tls_probe_call_done,
	.rpc_call_prepare	= rpc_tls_probe_call_prepare,
};

static int tls_probe(struct rpc_clnt *clnt)
{
	struct rpc_message msg = {
		.rpc_proc	= &rpcproc_tls_probe,
	};
	struct rpc_task_setup task_setup_data = {
		.rpc_client	= clnt,
		.rpc_message	= &msg,
		.rpc_op_cred	= &tls_cred,
		.callback_ops	= &rpc_tls_probe_ops,
		.flags		= RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
	};
	struct rpc_task	*task;
	int status;

	task = rpc_run_task(&task_setup_data);
	if (IS_ERR(task))
		return PTR_ERR(task);
	status = task->tk_status;
	rpc_put_task(task);
	return status;
}

static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
				   struct rpc_clnt *clnt)
{
	refcount_inc(&tls_auth.au_count);
	return &tls_auth;
}

/* ->destroy method: does nothing; @auth is ignored. */
static void tls_destroy(struct rpc_auth *auth)
{
	/* Intentionally empty. */
}

/*
 * ->lookup_cred method: every lookup resolves to the single static
 * credential, returned via get_rpccred(). @auth, @acred and @flags
 * are not consulted.
 */
static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
					struct auth_cred *acred, int flags)
{
	struct rpc_cred *cred = get_rpccred(&tls_cred);

	return cred;
}

/* ->crdestroy method: does nothing; @cred is ignored. */
static void tls_destroy_cred(struct rpc_cred *cred)
{
	/* Intentionally empty. */
}

/*
 * ->crmatch method: reports a match unconditionally, without looking
 * at @acred, @cred or @taskflags.
 *
 * Returns 1.
 */
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred,
		     int taskflags)
{
	const int matched = 1;

	return matched;
}

/*
 * ->crmarshal method: emit the credential and verifier for a probe.
 *
 * Reserves four XDR units in @xdr and fills them with two pairs, each
 * a flavor word followed by xdr_zero: rpc_auth_tls for the credential
 * and rpc_auth_null for the verifier. @task is unused.
 *
 * Returns 0 on success, or -EMSGSIZE if the space cannot be reserved.
 */
static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
{
	__be32 *words = xdr_reserve_space(xdr, 4 * XDR_UNIT);

	if (!words)
		return -EMSGSIZE;

	/* Credential */
	words[0] = rpc_auth_tls;
	words[1] = xdr_zero;

	/* Verifier */
	words[2] = rpc_auth_null;
	words[3] = xdr_zero;

	return 0;
}

static int tls_refresh(struct rpc_task *task)
{
	set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
	return 0;
}

/*
 * ->crvalidate method: check the verifier in a probe reply.
 *
 * The verifier must carry the rpc_auth_null flavor, followed by an
 * opaque body that is exactly the STARTTLS token. @task is unused.
 *
 * Returns 0 when the verifier matches, -EIO when the flavor word is
 * missing or is not rpc_auth_null, and -EPROTONOSUPPORT when the
 * opaque body has the wrong length or content.
 */
static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
{
	void *body;
	__be32 *flavor;

	flavor = xdr_inline_decode(xdr, XDR_UNIT);
	if (!flavor || *flavor != rpc_auth_null)
		return -EIO;

	/* The length is checked first; body is only read when it is valid. */
	if (xdr_stream_decode_opaque_inline(xdr, &body, starttls_len) != starttls_len ||
	    memcmp(body, starttls_token, starttls_len))
		return -EPROTONOSUPPORT;

	return 0;
}

/*
 * Authentication-flavor operations for RPC_AUTH_TLS. This object has
 * external linkage; the methods it points at are all file-local.
 *
 * NOTE(review): au_name is "NULL" although the flavor is RPC_AUTH_TLS;
 * confirm whether that is intentional before changing it.
 */
const struct rpc_authops authtls_ops = {
	.owner		= THIS_MODULE,
	.au_name	= "NULL",
	.au_flavor	= RPC_AUTH_TLS,
	.ping		= tls_probe,
	.lookup_cred	= tls_lookup_cred,
	.create		= tls_create,
	.destroy	= tls_destroy,
};

/*
 * The single rpc_auth instance returned by tls_create(). Its reference
 * count starts at 1, and the slack, verifier-size and alignment fields
 * reuse the NUL_* constants.
 */
static struct rpc_auth tls_auth = {
	.au_ops		= &authtls_ops,
	.au_flavor	= RPC_AUTH_TLS,
	.au_count	= REFCOUNT_INIT(1),
	.au_cslack	= NUL_CALLSLACK,
	.au_rslack	= NUL_REPLYSLACK,
	.au_ralign	= NUL_REPLYSLACK,
	.au_verfsize	= NUL_REPLYSLACK,
};

/*
 * Credential operations for the AUTH_TLS credential. Marshalling,
 * refresh, validation, matching and destruction use the file-local
 * methods; request wrapping and reply unwrapping use the
 * rpcauth_wrap_req_encode / rpcauth_unwrap_resp_decode helpers.
 */
static const struct rpc_credops tls_credops = {
	.cr_name	= "AUTH_TLS",
	.crmatch	= tls_match,
	.crmarshal	= tls_marshal,
	.crrefresh	= tls_refresh,
	.crvalidate	= tls_validate,
	.crwrap_req	= rpcauth_wrap_req_encode,
	.crunwrap_resp	= rpcauth_unwrap_resp_decode,
	.crdestroy	= tls_destroy_cred,
};

/*
 * The single, statically allocated AUTH_TLS credential, handed out by
 * tls_lookup_cred() and used directly by tls_probe().
 *
 * It is born with RPCAUTH_CRED_UPTODATE set and a reference count of 2
 * (presumably so the static object is never released; confirm against
 * the generic credential code).
 */
static struct rpc_cred tls_cred = {
	.cr_lru		= LIST_HEAD_INIT(tls_cred.cr_lru),
	.cr_auth	= &tls_auth,
	.cr_ops		= &tls_credops,
	.cr_count	= REFCOUNT_INIT(2),
	.cr_flags	= 1UL << RPCAUTH_CRED_UPTODATE,
};