/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * Cast6 Cipher 8-way parallel algorithm (AVX/x86_64)
 *
 * Copyright (C) 2012 Johannes Goetzfried
 *     <Johannes.Goetzfried@informatik.stud.uni-erlangen.de>
 *
 * Copyright © 2012-2013 Jussi Kivilinna <jussi.kivilinna@iki.fi>
 */

#include <linux/linkage.h>
#include <asm/frame.h>
#include "glue_helper-asm-avx.S"

.file "cast6-avx-x86_64-asm_64.S"

/*
 * S-box tables referenced by lookup_32bit (through the s1..s4 aliases).
 * They are not defined in this file.
 */
.extern cast_s1
.extern cast_s2
.extern cast_s3
.extern cast_s4

/*
 * Byte offsets into the crypto context addressed through CTX:
 *   km - 32-bit masking keys, read as (km + 4*index)(CTX)
 *   kr - rotation keys, read 16 bytes at a time as (kr + n*16)(CTX)
 */
#define km		0
#define kr		(12 * 4 * 4)

/* short names for the four external s-box tables */
#define s1		cast_s1
#define s2		cast_s2
#define s3		cast_s3
#define s4		cast_s4

/**********************************************************************
  8-way AVX cast6
 **********************************************************************/
/* crypto context pointer; %r15 is not touched by the lookup macros */
#define CTX	%r15

/* block state, first group (words A..D) */
#define RA1	%xmm0
#define RB1	%xmm1
#define RC1	%xmm2
#define RD1	%xmm3

/* block state, second group (words A..D) */
#define RA2	%xmm4
#define RB2	%xmm5
#define RC2	%xmm6
#define RD2	%xmm7

/* round-function value for the first group */
#define RX	%xmm8

/* per-round key material, see get_round_keys / preload_rkr */
#define RKM	%xmm9	/* masking key, broadcast to every lane */
#define RKR	%xmm10	/* queue of rotation-key bytes */
#define RKRF	%xmm11	/* left-shift count for the current round */
#define RKRR	%xmm12	/* right-shift count (R32 - RKRF) */
#define R32	%xmm13	/* loaded from .L32_mask */
#define R1ST	%xmm14	/* loaded from .Lfirst_mask */

/* vector scratch; also the round-function value for the second group */
#define RTMP	%xmm15

/* s-box index / table-base scratch used by lookup_32bit */
#define RID1	%rdi
#define RID1d	%edi
#define RID2	%rsi
#define RID2d	%esi

/*
 * General-purpose copies of the round-function input. These must be
 * registers that have bl/bh byte aliases.
 */
#define RGI1	%rdx
#define RGI1bl	%dl
#define RGI1bh	%dh
#define RGI2	%rcx
#define RGI2bl	%cl
#define RGI2bh	%ch
#define RGI3	%rax
#define RGI3bl	%al
#define RGI3bh	%ah
#define RGI4	%rbx
#define RGI4bl	%bl
#define RGI4bh	%bh

/* accumulators for the s-box lookup results */
#define RFS1	%r8
#define RFS1d	%r8d
#define RFS2	%r9
#define RFS2d	%r9d
#define RFS3	%r10
#define RFS3d	%r10d


/*
 * lookup_32bit: combine four s-box lookups, indexed by the four bytes of
 * the low 32 bits of 'src', into the 32-bit register 'dst ## d':
 *
 *	dst  =     s1[src bits 15..8]
 *	dst op1=   s2[src bits  7..0]
 *	dst op2=   s3[src bits 31..24]
 *	dst op3=   s4[src bits 23..16]
 *
 * 'src' must be a register that has 'bl'/'bh' byte aliases (the RGIn
 * registers). It is shifted right by 16 as a side effect.
 * interleave_op(il_reg) is expanded just before the last lookup, after
 * the final byte has already been extracted from 'src'; callers pass
 * shr_next to shift 'src' again, or dummy to do nothing.
 *
 * Clobbers RID1 and RID2 (index and table-base scratch).
 */
#define lookup_32bit(src, dst, op1, op2, op3, interleave_op, il_reg) \
	movzbl		src ## bh,     RID1d;    \
	leaq		s1(%rip),      RID2;     \
	movl		(RID2,RID1,4), dst ## d; \
	movzbl		src ## bl,     RID2d;    \
	leaq		s2(%rip),      RID1;     \
	op1		(RID1,RID2,4), dst ## d; \
	shrq $16,	src;                     \
	movzbl		src ## bh,     RID1d;    \
	leaq		s3(%rip),      RID2;     \
	op2		(RID2,RID1,4), dst ## d; \
	movzbl		src ## bl,     RID2d;    \
	interleave_op(il_reg);			 \
	leaq		s4(%rip),      RID1;     \
	op3		(RID1,RID2,4), dst ## d;

/* placeholder for a macro slot that should expand to nothing */
#define dummy(d)

/* shift 'reg' right by 16 bits */
#define shr_next(reg)	shrq $16, reg;

/*
 * F_head: vector part of the round function.
 *
 *	x   = RKM op0 a		(op0 is vpaddd, vpxor or vpsubd)
 *	x   = (x << RKRF) | (x >> RKRR), per 32-bit lane
 *	gi1 = low 64 bits of x
 *	gi2 = high 64 bits of x
 *
 * With RKRR = 32 - RKRF (as set up by get_round_keys) the shift pair is
 * a left rotation of each 32-bit lane. Clobbers RTMP.
 */
#define F_head(a, x, gi1, gi2, op0) \
	op0	a,	RKM,  x;                 \
	vpslld	RKRF,	x,    RTMP;              \
	vpsrld	RKRR,	x,    x;                 \
	vpor	RTMP,	x,    x;                 \
	\
	vmovq		x,    gi1;               \
	vpextrq $1,	x,    gi2;

/*
 * F_tail: s-box part of the round function.
 *
 * gi1 and gi2 each hold two 32-bit words (as produced by F_head). Each
 * word goes through lookup_32bit and the four 32-bit results are packed
 * back into the vector register 'x':
 *
 *	x[low  64] = lookup(gi1 high word) << 32 | lookup(gi1 low word)
 *	x[high 64] = lookup(gi2 high word) << 32 | lookup(gi2 low word)
 *
 * The first two lookups pass shr_next so that gi1/gi2 end up shifted by
 * a full 32 bits, exposing the high word for the second pair of lookups.
 *
 * The parameter 'a' is not used in the body.
 * Clobbers gi1, gi2, RFS1, RFS2, RFS3, RID1 and RID2.
 */
#define F_tail(a, x, gi1, gi2, op1, op2, op3) \
	lookup_32bit(##gi1, RFS1, op1, op2, op3, shr_next, ##gi1); \
	lookup_32bit(##gi2, RFS3, op1, op2, op3, shr_next, ##gi2); \
	\
	lookup_32bit(##gi1, RFS2, op1, op2, op3, dummy, none);     \
	shlq $32,	RFS2;                                      \
	orq		RFS1, RFS2;                                \
	lookup_32bit(##gi2, RFS1, op1, op2, op3, dummy, none);     \
	shlq $32,	RFS1;                                      \
	orq		RFS1, RFS3;                                \
	\
	vmovq		RFS2, x;                                   \
	vpinsrq $1,	RFS3, x, x;

/*
 * F_2: apply one round function to both groups of blocks.
 *
 *	a1 ^= f(b1)	(first group)
 *	a2 ^= f(b2)	(second group)
 *
 * Both F_head expansions run first, so the inputs for the first group
 * sit in RGI1/RGI2 and those for the second group in RGI3/RGI4 before
 * any lookup starts. RX is used as the vector scratch for both heads;
 * the tails then rebuild the results in RX (first group) and RTMP
 * (second group).
 *
 * op0 selects the key-mixing operation, op1..op3 the s-box combining
 * operations (see F1_2, F2_2, F3_2).
 */
#define F_2(a1, b1, a2, b2, op0, op1, op2, op3) \
	F_head(b1, RX, RGI1, RGI2, op0);              \
	F_head(b2, RX, RGI3, RGI4, op0);              \
	\
	F_tail(b1, RX, RGI1, RGI2, op1, op2, op3);    \
	F_tail(b2, RTMP, RGI3, RGI4, op1, op2, op3);  \
	\
	vpxor		a1, RX,   a1;                 \
	vpxor		a2, RTMP, a2;

/*
 * The three round-function variants. Each one fixes the key-mixing
 * operation (op0) and the three s-box combining operations (op1..op3)
 * that are handed to F_2.
 */
#define F1_2(a1, b1, a2, b2) F_2(a1, b1, a2, b2, vpaddd, xorl, subl, addl)
#define F2_2(a1, b1, a2, b2) F_2(a1, b1, a2, b2, vpxor, subl, addl, xorl)
#define F3_2(a1, b1, a2, b2) F_2(a1, b1, a2, b2, vpsubd, addl, xorl, subl)

/*
 * qop(from, to, fn): expands to F<fn>_2(to1, from1, to2, from2), i.e. the
 * round function number 'fn' applied to both groups of blocks.
 */
#define qop(from, to, fn) \
	F ## fn ## _2(to ## 1, from ## 1, to ## 2, from ## 2);

/*
 * get_round_keys(nn): set up the key registers for round 'nn'.
 *
 *	RKM  = 32-bit masking key at (km + 4*nn)(CTX), broadcast to all lanes
 *	RKRF = RKR & R1ST	(left-shift count for this round)
 *	RKRR = R32 - RKRF	(matching right-shift count)
 *	RKR  = RKR shifted right by one byte, so the next call sees the
 *	       following rotation key in the lowest byte
 *
 * RKR must have been loaded by preload_rkr with its bytes in the order
 * in which the rounds will consume them.
 */
#define get_round_keys(nn) \
	vbroadcastss	(km+(4*(nn)))(CTX), RKM;        \
	vpand		R1ST,               RKR,  RKRF; \
	vpsubq		RKRF,               R32,  RKRR; \
	vpsrldq $1,	RKR,                RKR;

/*
 * Q(n): four rounds using masking keys 4*n+0 .. 4*n+3 in ascending order:
 *
 *	C ^= f1(D);  B ^= f2(C);  A ^= f3(B);  D ^= f1(A)
 *
 * Each round takes the next rotation-key byte from RKR.
 */
#define Q(n) \
	get_round_keys(4*n+0); \
	qop(RD, RC, 1);        \
	\
	get_round_keys(4*n+1); \
	qop(RC, RB, 2);        \
	\
	get_round_keys(4*n+2); \
	qop(RB, RA, 3);        \
	\
	get_round_keys(4*n+3); \
	qop(RA, RD, 1);

/*
 * QBAR(n): the four rounds of Q(n) in the opposite order, using masking
 * keys 4*n+3 down to 4*n+0:
 *
 *	D ^= f1(A);  A ^= f3(B);  B ^= f2(C);  C ^= f1(D)
 *
 * Each round takes the next rotation-key byte from RKR, so RKR has to be
 * shuffled beforehand to present the bytes in this descending order.
 */
#define QBAR(n) \
	get_round_keys(4*n+3); \
	qop(RA, RD, 1);        \
	\
	get_round_keys(4*n+2); \
	qop(RB, RA, 3);        \
	\
	get_round_keys(4*n+1); \
	qop(RC, RB, 2);        \
	\
	get_round_keys(4*n+0); \
	qop(RD, RC, 1);

/* reorder the bytes of RKR with the 16-byte shuffle table at 'mask' */
#define shuffle(mask)	vpshufb mask(%rip), RKR, RKR;

/*
 * preload_rkr(n, do_mask, mask): load the n-th 16-byte group of rotation
 * keys from (kr + n*16)(CTX) into RKR, XORing every byte with 16, then
 * expand do_mask(mask). Callers pass 'shuffle' with a byte-order table,
 * or 'dummy' to keep the bytes in memory order.
 *
 * get_round_keys consumes one byte of RKR per round, so one preload
 * covers sixteen rounds (four Q/QBAR expansions).
 */
#define preload_rkr(n, do_mask, mask) \
	vbroadcastss	.L16_mask(%rip),          RKR;      \
	/* add 16-bit rotation to key rotations (mod 32) */ \
	vpxor		(kr+n*16)(CTX),           RKR, RKR; \
	do_mask(mask);

/*
 * transpose_4x4: treat x0..x3 as the rows of a 4x4 matrix of 32-bit
 * words and transpose it in place. t0, t1 and t2 are clobbered as
 * scratch. x3 is reused as a fourth temporary in the first half, so the
 * instruction order must be kept.
 */
#define transpose_4x4(x0, x1, x2, x3, t0, t1, t2) \
	vpunpckldq		x1, x0, t0; \
	vpunpckhdq		x1, x0, t2; \
	vpunpckldq		x3, x2, t1; \
	vpunpckhdq		x3, x2, x3; \
	\
	vpunpcklqdq		t1, t0, x0; \
	vpunpckhqdq		t1, t0, x1; \
	vpunpcklqdq		x3, t2, x2; \
	vpunpckhqdq		x3, t2, x3;

/*
 * inpack_blocks: byte-shuffle each of x0..x3 with 'rmask', then transpose
 * the four registers as a 4x4 matrix of 32-bit words.
 * t0, t1 and t2 are scratch registers for the transpose.
 */
#define inpack_blocks(x0, x1, x2, x3, t0, t1, t2, rmask) \
	vpshufb		rmask, x3, x3; \
	vpshufb		rmask, x2, x2; \
	vpshufb		rmask, x1, x1; \
	vpshufb		rmask, x0, x0; \
	transpose_4x4(x0, x1, x2, x3, t0, t1, t2)

/*
 * outunpack_blocks: transpose x0..x3 as a 4x4 matrix of 32-bit words,
 * then byte-shuffle each register with 'rmask'.
 * t0, t1 and t2 are scratch registers for the transpose.
 */
#define outunpack_blocks(x0, x1, x2, x3, t0, t1, t2, rmask) \
	transpose_4x4(x0, x1, x2, x3, t0, t1, t2) \
	vpshufb		rmask, x3, x3; \
	vpshufb		rmask, x2, x2; \
	vpshufb		rmask, x1, x1; \
	vpshufb		rmask, x0, x0;

/*
 * 16-byte vpshufb tables. Entry i names the source byte that ends up in
 * destination byte i.
 */
.section	.rodata.cst16, "aM", @progbits, 16
.align 16
/* reverse the byte order inside each 32-bit word */
.Lbswap_mask:
	.byte 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12
/* reverse the byte order of the whole 128-bit value */
.Lbswap128_mask:
	.byte 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
/*
 * Rotation-key orderings applied by preload_rkr. The label lists the
 * Q/QBAR expansions that consume the sixteen bytes: a Q takes its four
 * bytes in ascending order, a QBAR in descending order.
 */
.Lrkr_enc_Q_Q_QBAR_QBAR:
	.byte 0, 1, 2, 3, 4, 5, 6, 7, 11, 10, 9, 8, 15, 14, 13, 12
.Lrkr_enc_QBAR_QBAR_QBAR_QBAR:
	.byte 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12
.Lrkr_dec_Q_Q_Q_Q:
	.byte 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3
.Lrkr_dec_Q_Q_QBAR_QBAR:
	.byte 12, 13, 14, 15, 8, 9, 10, 11, 7, 6, 5, 4, 3, 2, 1, 0
.Lrkr_dec_QBAR_QBAR_QBAR_QBAR:
	.byte 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0

/* 16 in every byte; XORed into the rotation keys by preload_rkr */
.section	.rodata.cst4.L16_mask, "aM", @progbits, 4
.align 4
.L16_mask:
	.byte 16, 16, 16, 16

/* the value 32; loaded into R32 to derive the right-shift count */
.section	.rodata.cst4.L32_mask, "aM", @progbits, 4
.align 4
.L32_mask:
	.byte 32, 0, 0, 0

/* 0x1f; loaded into R1ST to isolate the current rotation count */
.section	.rodata.cst4.first_mask, "aM", @progbits, 4
.align 4
.Lfirst_mask:
	.byte 0x1f, 0, 0, 0

.text

.align 8
SYM_FUNC_START_LOCAL(__cast6_enc_blk8)
	/* input:
	 *	%rdi: ctx
	 *	RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2: blocks
	 * output:
	 *	RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2: encrypted blocks
	 *
	 * %r15 and %rbx are saved and restored here. The round macros
	 * overwrite %rax, %rcx, %rdx, %rsi, %rdi, %r8-%r10 and
	 * %xmm8-%xmm15.
	 */

	pushq %r15;
	pushq %rbx;

	/* %rdi doubles as RID1 scratch, so keep ctx in CTX */
	movq %rdi, CTX;

	/* RKM temporarily holds the byte-swap mask for inpack_blocks */
	vmovdqa .Lbswap_mask(%rip), RKM;
	vmovd .Lfirst_mask(%rip), R1ST;
	vmovd .L32_mask(%rip), R32;

	inpack_blocks(RA1, RB1, RC1, RD1, RTMP, RX, RKRF, RKM);
	inpack_blocks(RA2, RB2, RC2, RD2, RTMP, RX, RKRF, RKM);

	/* rotation keys 0..15 are consumed in memory order */
	preload_rkr(0, dummy, none);
	Q(0);
	Q(1);
	Q(2);
	Q(3);
	preload_rkr(1, shuffle, .Lrkr_enc_Q_Q_QBAR_QBAR);
	Q(4);
	Q(5);
	QBAR(6);
	QBAR(7);
	preload_rkr(2, shuffle, .Lrkr_enc_QBAR_QBAR_QBAR_QBAR);
	QBAR(8);
	QBAR(9);
	QBAR(10);
	QBAR(11);

	popq %rbx;
	popq %r15;

	/* RKM was overwritten by the round keys; reload the swap mask */
	vmovdqa .Lbswap_mask(%rip), RKM;

	outunpack_blocks(RA1, RB1, RC1, RD1, RTMP, RX, RKRF, RKM);
	outunpack_blocks(RA2, RB2, RC2, RD2, RTMP, RX, RKRF, RKM);

	RET;
SYM_FUNC_END(__cast6_enc_blk8)

.align 8
SYM_FUNC_START_LOCAL(__cast6_dec_blk8)
	/* input:
	 *	%rdi: ctx
	 *	RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2: encrypted blocks
	 * output:
	 *	RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2: decrypted blocks
	 *
	 * %r15 and %rbx are saved and restored here. The round macros
	 * overwrite %rax, %rcx, %rdx, %rsi, %rdi, %r8-%r10 and
	 * %xmm8-%xmm15.
	 */

	pushq %r15;
	pushq %rbx;

	/* %rdi doubles as RID1 scratch, so keep ctx in CTX */
	movq %rdi, CTX;

	/* RKM temporarily holds the byte-swap mask for inpack_blocks */
	vmovdqa .Lbswap_mask(%rip), RKM;
	vmovd .Lfirst_mask(%rip), R1ST;
	vmovd .L32_mask(%rip), R32;

	inpack_blocks(RA1, RB1, RC1, RD1, RTMP, RX, RKRF, RKM);
	inpack_blocks(RA2, RB2, RC2, RD2, RTMP, RX, RKRF, RKM);

	/*
	 * Key groups are walked from the last (2) to the first (0), with
	 * the quad-round index counting down from 11 to 0. Each shuffle
	 * table orders the rotation keys to match.
	 */
	preload_rkr(2, shuffle, .Lrkr_dec_Q_Q_Q_Q);
	Q(11);
	Q(10);
	Q(9);
	Q(8);
	preload_rkr(1, shuffle, .Lrkr_dec_Q_Q_QBAR_QBAR);
	Q(7);
	Q(6);
	QBAR(5);
	QBAR(4);
	preload_rkr(0, shuffle, .Lrkr_dec_QBAR_QBAR_QBAR_QBAR);
	QBAR(3);
	QBAR(2);
	QBAR(1);
	QBAR(0);

	popq %rbx;
	popq %r15;

	/* RKM was overwritten by the round keys; reload the swap mask */
	vmovdqa .Lbswap_mask(%rip), RKM;
	outunpack_blocks(RA1, RB1, RC1, RD1, RTMP, RX, RKRF, RKM);
	outunpack_blocks(RA2, RB2, RC2, RD2, RTMP, RX, RKRF, RKM);

	RET;
SYM_FUNC_END(__cast6_dec_blk8)

SYM_FUNC_START(cast6_ecb_enc_8way)
	/*
	 * Run __cast6_enc_blk8 on the data loaded from src and store the
	 * result to dst.
	 *
	 * input:
	 *	%rdi: ctx
	 *	%rsi: dst
	 *	%rdx: src
	 */
	FRAME_BEGIN
	pushq	%r15

	/* the block helper uses %rsi as scratch, so park dst in %r11 */
	movq	%rsi, %r11
	movq	%rdi, CTX

	load_8way(%rdx, RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2);

	call	__cast6_enc_blk8

	store_8way(%r11, RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2);

	popq	%r15
	FRAME_END
	RET;
SYM_FUNC_END(cast6_ecb_enc_8way)

SYM_FUNC_START(cast6_ecb_dec_8way)
	/*
	 * Run __cast6_dec_blk8 on the data loaded from src and store the
	 * result to dst.
	 *
	 * input:
	 *	%rdi: ctx
	 *	%rsi: dst
	 *	%rdx: src
	 */
	FRAME_BEGIN
	pushq	%r15

	/* the block helper uses %rsi as scratch, so park dst in %r11 */
	movq	%rsi, %r11
	movq	%rdi, CTX

	load_8way(%rdx, RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2);

	call	__cast6_dec_blk8

	store_8way(%r11, RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2);

	popq	%r15
	FRAME_END
	RET;
SYM_FUNC_END(cast6_ecb_dec_8way)

SYM_FUNC_START(cast6_cbc_dec_8way)
	/* input:
	 *	%rdi: ctx
	 *	%rsi: dst
	 *	%rdx: src
	 */
	FRAME_BEGIN
	pushq %r12;
	pushq %r15;

	movq %rdi, CTX;
	movq %rsi, %r11;
	movq %rdx, %r12;

	load_8way(%rdx, RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2);

	call __cast6_dec_blk8;

	store_cbc_8way(%r12, %r11, RA1, RB1, RC1, RD1, RA2, RB2, RC2, RD2);

	popq %r15;
	popq %r12;
	FRAME_END
	RET;
SYM_FUNC_END