/* SPDX-License-Identifier: GPL-2.0-only */
/*
 * linux/arch/arm64/crypto/aes-modes.S - chaining mode wrappers for AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

/* included by aes-ce.S and aes-neon.S */

	.text
	.align		4

/*
 * MAX_STRIDE is the number of blocks handled per iteration of the
 * interleaved loops below.  The including file may define it first;
 * otherwise it defaults to 4.  Interleaved helpers exist in this file only
 * for strides 4 and 5 (the 5x ones are assembled only when MAX_STRIDE == 5).
 */
#ifndef MAX_STRIDE
#define MAX_STRIDE	4
#endif

/*
 * ST4(insn) emits insn only when MAX_STRIDE == 4, and ST5(insn) emits insn
 * for any other stride (in practice 5).  This lets shared code paths carry
 * stride-specific instructions on a single line.
 */
#if MAX_STRIDE == 4
#define ST4(x...) x
#define ST5(x...)
#else
#define ST4(x...)
#define ST5(x...) x
#endif

/*
 * Local helpers that run the interleaved encrypt_block{4,5}x /
 * decrypt_block{4,5}x macros on v0-v3 (and v4 for the 5x variants).
 * Callers load the data into those registers before the bl and store the
 * results from the same registers afterwards.  w3 is the round count and
 * x2 the round-key pointer, exactly as the mode functions below receive
 * them.  x8 and w7 are passed to the macros as additional operands
 * (presumably scratch -- confirm against the macro definitions in the
 * including file).  Every caller reaches these with bl, and each one does
 * a frame_push before doing so.
 */
SYM_FUNC_START_LOCAL(aes_encrypt_block4x)
	encrypt_block4x	v0, v1, v2, v3, w3, x2, x8, w7
	ret
SYM_FUNC_END(aes_encrypt_block4x)

SYM_FUNC_START_LOCAL(aes_decrypt_block4x)
	decrypt_block4x	v0, v1, v2, v3, w3, x2, x8, w7
	ret
SYM_FUNC_END(aes_decrypt_block4x)

/* 5-way variants: only needed, and only assembled, for MAX_STRIDE == 5. */
#if MAX_STRIDE == 5
SYM_FUNC_START_LOCAL(aes_encrypt_block5x)
	encrypt_block5x	v0, v1, v2, v3, v4, w3, x2, x8, w7
	ret
SYM_FUNC_END(aes_encrypt_block5x)

SYM_FUNC_START_LOCAL(aes_decrypt_block5x)
	decrypt_block5x	v0, v1, v2, v3, v4, w3, x2, x8, w7
	ret
SYM_FUNC_END(aes_decrypt_block5x)
#endif

	/*
	 * aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int blocks)
	 * aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int blocks)
	 */

AES_FUNC_START(aes_ecb_encrypt)
	/*
	 * x0 = out, x1 = in, x2 = round keys, w3 = rounds, w4 = block count.
	 * ECB blocks are independent.  They are encrypted MAX_STRIDE at a
	 * time through the interleaved helper while enough remain, and the
	 * remainder one by one.  Both pointers are post-incremented.  The
	 * frame is presumably needed to preserve x30 across the bl below.
	 */
	frame_push	0
	enc_prepare	w3, x2, x5

.Lecb_enc_stride:
	subs		w4, w4, #MAX_STRIDE		// claim one full stride
	bmi		.Lecb_enc_tail			// fewer than MAX_STRIDE left
	ld1		{v0.16b-v3.16b}, [x1], #64
#if MAX_STRIDE == 4
	bl		aes_encrypt_block4x
	st1		{v0.16b-v3.16b}, [x0], #64
#else
	ld1		{v4.16b}, [x1], #16
	bl		aes_encrypt_block5x
	st1		{v0.16b-v3.16b}, [x0], #64
	st1		{v4.16b}, [x0], #16
#endif
	b		.Lecb_enc_stride

.Lecb_enc_tail:
	adds		w4, w4, #MAX_STRIDE		// undo overshoot: 0..MAX_STRIDE-1 left
	beq		.Lecb_enc_exit
.Lecb_enc_single:
	ld1		{v0.16b}, [x1], #16
	encrypt_block	v0, w3, x2, x5, w6
	st1		{v0.16b}, [x0], #16
	subs		w4, w4, #1
	bne		.Lecb_enc_single
.Lecb_enc_exit:
	frame_pop
	ret
AES_FUNC_END(aes_ecb_encrypt)


AES_FUNC_START(aes_ecb_decrypt)
	/*
	 * x0 = out, x1 = in, x2 = round keys, w3 = rounds, w4 = block count.
	 * Mirror image of aes_ecb_encrypt using the decryption key schedule.
	 * Blocks go MAX_STRIDE at a time through the interleaved helper while
	 * enough remain, then one at a time.  Both pointers are
	 * post-incremented.  The frame is presumably needed to preserve x30
	 * across the bl below.
	 */
	frame_push	0
	dec_prepare	w3, x2, x5

.Lecb_dec_stride:
	subs		w4, w4, #MAX_STRIDE		// claim one full stride
	bmi		.Lecb_dec_tail			// fewer than MAX_STRIDE left
	ld1		{v0.16b-v3.16b}, [x1], #64
#if MAX_STRIDE == 4
	bl		aes_decrypt_block4x
	st1		{v0.16b-v3.16b}, [x0], #64
#else
	ld1		{v4.16b}, [x1], #16
	bl		aes_decrypt_block5x
	st1		{v0.16b-v3.16b}, [x0], #64
	st1		{v4.16b}, [x0], #16
#endif
	b		.Lecb_dec_stride

.Lecb_dec_tail:
	adds		w4, w4, #MAX_STRIDE		// undo overshoot: 0..MAX_STRIDE-1 left
	beq		.Lecb_dec_exit
.Lecb_dec_single:
	ld1		{v0.16b}, [x1], #16
	decrypt_block	v0, w3, x2, x5, w6
	st1		{v0.16b}, [x0], #16
	subs		w4, w4, #1
	bne		.Lecb_dec_single
.Lecb_dec_exit:
	frame_pop
	ret
AES_FUNC_END(aes_ecb_decrypt)


	/*
	 * aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int blocks, u8 iv[])
	 * aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int blocks, u8 iv[])
	 * aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
	 *			 int rounds, int blocks, u8 iv[],
	 *			 u32 const rk2[]);
	 * aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
	 *			 int rounds, int blocks, u8 iv[],
	 *			 u32 const rk2[]);
	 */

AES_FUNC_START(aes_essiv_cbc_encrypt)
	/*
	 * ESSIV entry (x6 = rk2): the caller's iv is first encrypted with rk2
	 * using a fixed 14 rounds (AES-256) to form the actual IV in v4.
	 * enc_switch_key then moves to the data key (x2, w3), and execution
	 * joins the shared CBC loop below.
	 */
	ld1		{v4.16b}, [x5]			/* get iv */

	mov		w8, #14				/* AES-256: 14 rounds */
	enc_prepare	w8, x6, x7
	encrypt_block	v4, w8, x6, x7, w9
	enc_switch_key	w3, x2, x6
	b		.Lcbcencloop4x

AES_FUNC_START(aes_cbc_encrypt)
	/*
	 * x0 = out, x1 = in, x2 = rk, w3 = rounds, w4 = blocks, x5 = iv.
	 * v4 is the chaining value: the IV, then each ciphertext block in
	 * turn.  Every block depends on the previous one, so the 4x loop only
	 * batches loads and stores; the encryptions themselves stay serial.
	 * No bl is made, so no frame is set up.
	 */
	ld1		{v4.16b}, [x5]			/* get iv */
	enc_prepare	w3, x2, x6

.Lcbcencloop4x:
	subs		w4, w4, #4
	bmi		.Lcbcenc1x
	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 pt blocks */
	eor		v0.16b, v0.16b, v4.16b		/* ..and xor with iv */
	encrypt_block	v0, w3, x2, x6, w7
	eor		v1.16b, v1.16b, v0.16b
	encrypt_block	v1, w3, x2, x6, w7
	eor		v2.16b, v2.16b, v1.16b
	encrypt_block	v2, w3, x2, x6, w7
	eor		v3.16b, v3.16b, v2.16b
	encrypt_block	v3, w3, x2, x6, w7
	st1		{v0.16b-v3.16b}, [x0], #64
	mov		v4.16b, v3.16b			/* last ct chains into next batch */
	b		.Lcbcencloop4x
.Lcbcenc1x:
	adds		w4, w4, #4			/* 0-3 blocks left */
	beq		.Lcbcencout
.Lcbcencloop:
	ld1		{v0.16b}, [x1], #16		/* get next pt block */
	eor		v4.16b, v4.16b, v0.16b		/* ..and xor with iv */
	encrypt_block	v4, w3, x2, x6, w7
	st1		{v4.16b}, [x0], #16
	subs		w4, w4, #1
	bne		.Lcbcencloop
.Lcbcencout:
	st1		{v4.16b}, [x5]			/* return iv (last ct block) */
	ret
AES_FUNC_END(aes_cbc_encrypt)
AES_FUNC_END(aes_essiv_cbc_encrypt)

AES_FUNC_START(aes_essiv_cbc_decrypt)
	/*
	 * ESSIV entry (x6 = rk2): derive the actual IV by encrypting iv with
	 * rk2 using a fixed 14 rounds (AES-256), then share the CBC decrypt
	 * path.  cbciv is a vector register alias, presumably defined by the
	 * including file; it holds the chaining value (IV, then the previous
	 * ciphertext block) throughout.
	 */
	ld1		{cbciv.16b}, [x5]		/* get iv */

	mov		w8, #14				/* AES-256: 14 rounds */
	enc_prepare	w8, x6, x7
	encrypt_block	cbciv, w8, x6, x7, w9
	b		.Lessivcbcdecstart

AES_FUNC_START(aes_cbc_decrypt)
	ld1		{cbciv.16b}, [x5]		/* get iv */
.Lessivcbcdecstart:
	frame_push	0
	dec_prepare	w3, x2, x6

	/*
	 * Unlike encryption, CBC decryption parallelizes.  Decrypt MAX_STRIDE
	 * blocks at once, then XOR each with the preceding ciphertext block.
	 * The first ciphertext blocks are saved in spare vector registers
	 * before the call.  The last one or two are reloaded from memory
	 * afterwards.  The reloads precede the st1 below, so this also works
	 * when out == in.
	 */
.LcbcdecloopNx:
	subs		w4, w4, #MAX_STRIDE
	bmi		.Lcbcdec1x
	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 ct blocks */
#if MAX_STRIDE == 5
	ld1		{v4.16b}, [x1], #16		/* get 1 ct block */
	mov		v5.16b, v0.16b
	mov		v6.16b, v1.16b
	mov		v7.16b, v2.16b
	bl		aes_decrypt_block5x
	sub		x1, x1, #32			/* back to ct block 3 */
	eor		v0.16b, v0.16b, cbciv.16b
	eor		v1.16b, v1.16b, v5.16b
	ld1		{v5.16b}, [x1], #16		/* reload 1 ct block */
	ld1		{cbciv.16b}, [x1], #16		/* reload 1 ct block */
	eor		v2.16b, v2.16b, v6.16b
	eor		v3.16b, v3.16b, v7.16b
	eor		v4.16b, v4.16b, v5.16b
#else
	mov		v4.16b, v0.16b
	mov		v5.16b, v1.16b
	mov		v6.16b, v2.16b
	bl		aes_decrypt_block4x
	sub		x1, x1, #16			/* back to ct block 3 */
	eor		v0.16b, v0.16b, cbciv.16b
	eor		v1.16b, v1.16b, v4.16b
	ld1		{cbciv.16b}, [x1], #16		/* reload 1 ct block */
	eor		v2.16b, v2.16b, v5.16b
	eor		v3.16b, v3.16b, v6.16b
#endif
	st1		{v0.16b-v3.16b}, [x0], #64
ST5(	st1		{v4.16b}, [x0], #16		)
	b		.LcbcdecloopNx
.Lcbcdec1x:
	adds		w4, w4, #MAX_STRIDE		/* 0..MAX_STRIDE-1 blocks left */
	beq		.Lcbcdecout
.Lcbcdecloop:
	ld1		{v1.16b}, [x1], #16		/* get next ct block */
	mov		v0.16b, v1.16b			/* ...and copy to v0 */
	decrypt_block	v0, w3, x2, x6, w7
	eor		v0.16b, v0.16b, cbciv.16b	/* xor with iv => pt */
	mov		cbciv.16b, v1.16b		/* ct is next iv */
	st1		{v0.16b}, [x0], #16
	subs		w4, w4, #1
	bne		.Lcbcdecloop
.Lcbcdecout:
	st1		{cbciv.16b}, [x5]		/* return iv */
	frame_pop
	ret
AES_FUNC_END(aes_cbc_decrypt)
AES_FUNC_END(aes_essiv_cbc_decrypt)


	/*
	 * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
	 *		       int rounds, int bytes, u8 const iv[])
	 * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
	 *		       int rounds, int bytes, u8 const iv[])
	 */

AES_FUNC_START(aes_cbc_cts_encrypt)
	/*
	 * CBC ciphertext stealing over the final bytes of a message.
	 * x0 = out, x1 = in, x2 = rk, w3 = rounds, w4 = bytes, x5 = iv (only
	 * read).  bytes is assumed to be in 17..32 so that d = bytes - 16 is
	 * 1..16 -- TODO confirm against callers.
	 *
	 * With X = E(in[0..15] ^ iv), the result is:
	 *   out[0..15]      = E(X ^ (trailing d input bytes, zero padded))
	 *   out[16..16+d)   = X[0..d-1]
	 * It is written with two overlapping 16-byte stores.
	 *
	 * bytes is an int, so only w4 is defined on entry; AAPCS64 leaves the
	 * upper 32 bits of x4 unspecified.  d is therefore computed with a
	 * 32-bit sub, which zero-extends into x4, before x4 feeds the 64-bit
	 * address arithmetic below.
	 */
	adr_l		x8, .Lcts_permute_table
	sub		w4, w4, #16			/* x4 = d, upper bits cleared */
	add		x9, x8, #32
	add		x8, x8, x4
	sub		x9, x9, x4
	ld1		{v3.16b}, [x8]			/* bytes 0..d-1 -> top d lanes */
	ld1		{v4.16b}, [x9]			/* top d lanes -> bytes 0..d-1 */

	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
	ld1		{v1.16b}, [x1]

	ld1		{v5.16b}, [x5]			/* get iv */
	enc_prepare	w3, x2, x6

	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */
	tbl		v1.16b, {v1.16b}, v4.16b	/* partial block, zero padded */
	encrypt_block	v0, w3, x2, x6, w7

	eor		v1.16b, v1.16b, v0.16b
	tbl		v0.16b, {v0.16b}, v3.16b
	encrypt_block	v1, w3, x2, x6, w7

	add		x4, x0, x4
	st1		{v0.16b}, [x4]			/* overlapping stores */
	st1		{v1.16b}, [x0]
	ret
AES_FUNC_END(aes_cbc_cts_encrypt)

AES_FUNC_START(aes_cbc_cts_decrypt)
	/*
	 * CBC ciphertext stealing, decrypt direction.
	 * x0 = out, x1 = in, x2 = rk, w3 = rounds, w4 = bytes, x5 = iv (only
	 * read).  bytes is assumed to be in 17..32 so that d = bytes - 16 is
	 * 1..16 -- TODO confirm against callers.
	 *
	 * in[0..15] is a full ciphertext block and in[16..16+d) a truncated
	 * one.  With X = D(in[0..15]):
	 *   out[16..16+d) = X[0..d-1] ^ in[16..16+d)
	 *   out[0..15]    = D(in[16..16+d) || X[d..15]) ^ iv
	 * It is written with two overlapping 16-byte stores.
	 *
	 * bytes is an int, so only w4 is defined on entry; AAPCS64 leaves the
	 * upper 32 bits of x4 unspecified.  d is therefore computed with a
	 * 32-bit sub, which zero-extends into x4, before x4 feeds the 64-bit
	 * address arithmetic below.
	 */
	adr_l		x8, .Lcts_permute_table
	sub		w4, w4, #16			/* x4 = d, upper bits cleared */
	add		x9, x8, #32
	add		x8, x8, x4
	sub		x9, x9, x4
	ld1		{v3.16b}, [x8]			/* bytes 0..d-1 -> top d lanes */
	ld1		{v4.16b}, [x9]			/* top d lanes -> bytes 0..d-1 */

	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
	ld1		{v1.16b}, [x1]

	ld1		{v5.16b}, [x5]			/* get iv */
	dec_prepare	w3, x2, x6

	decrypt_block	v0, w3, x2, x6, w7
	tbl		v2.16b, {v0.16b}, v3.16b
	eor		v2.16b, v2.16b, v1.16b		/* top d lanes: final pt bytes */

	tbx		v0.16b, {v1.16b}, v4.16b	/* splice truncated ct into X */
	decrypt_block	v0, w3, x2, x6, w7
	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */

	add		x4, x0, x4
	st1		{v2.16b}, [x4]			/* overlapping stores */
	st1		{v0.16b}, [x0]
	ret
AES_FUNC_END(aes_cbc_cts_decrypt)

	.section	".rodata", "a"
	.align		6
	/*
	 * Byte-index table for tbl/tbx, shared by the CTS, XTS and CTR tail
	 * code.  Layout: 16 x 0xff, the identity 0x0-0xf, then 16 x 0xff.
	 * A 16-byte load at offset n (0 <= n <= 32) yields indices that shift
	 * a vector by |16 - n| byte positions: towards the top lanes for
	 * n < 16 and towards the bottom lanes for n > 16.  The 0xff entries
	 * are out of range, so tbl writes zero there and tbx leaves the
	 * destination byte unchanged.
	 */
.Lcts_permute_table:
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.previous

	/*
	 * This macro generates the code for CTR and XCTR mode.
	 *
	 * xctr == 0 (CTR): IV points to a 128-bit big-endian counter block.
	 *   Its low 64 bits are kept byte-reversed into native order in
	 *   IV_PART and advanced by BLOCKS each iteration.  A carry out of
	 *   them is propagated into the high 64 bits (vctr.d[0]).  The next
	 *   unused counter value is written back to IV on return.
	 *
	 * xctr == 1 (XCTR): every counter block is the IV with a 64-bit block
	 *   index XORed into its first 8 bytes (vctr.d[0]).  The first block
	 *   processed uses index BYTE_CTR_W / 16 + 1.  IV is only read.
	 *
	 * vctr is a vector register alias, presumably provided by the
	 * including file (TODO confirm); it holds the IV / running counter.
	 */
.macro ctr_encrypt xctr
	// Arguments
	OUT		.req x0
	IN		.req x1
	KEY		.req x2
	ROUNDS_W	.req w3
	BYTES_W		.req w4
	IV		.req x5
	BYTE_CTR_W 	.req w6		// XCTR only
	// Intermediate values
	CTR_W		.req w11	// XCTR only
	CTR		.req x11	// XCTR only
	IV_PART		.req x12
	BLOCKS		.req x13
	BLOCKS_W	.req w13

	frame_push	0

	enc_prepare	ROUNDS_W, KEY, IV_PART
	ld1		{vctr.16b}, [IV]

	/*
	 * Keep 64 bits of the IV in a register.  For CTR mode this lets us
	 * easily increment the IV.  For XCTR mode this lets us efficiently XOR
	 * the 64-bit counter with the IV.
	 */
	.if \xctr
		umov		IV_PART, vctr.d[0]
		lsr		CTR_W, BYTE_CTR_W, #4
	.else
		umov		IV_PART, vctr.d[1]
		rev		IV_PART, IV_PART
	.endif

	/*
	 * Per iteration: BLOCKS = min(ceil(BYTES_W / 16), MAX_STRIDE).
	 * BYTES_W is decremented by a full stride up front.  It becomes
	 * negative (bit 31 set, tested by the tbnz further down) exactly when
	 * fewer than MAX_STRIDE * 16 bytes were left, which sends that
	 * iteration to the tail code.
	 */
.LctrloopNx\xctr:
	add		BLOCKS_W, BYTES_W, #15
	sub		BYTES_W, BYTES_W, #MAX_STRIDE << 4
	lsr		BLOCKS_W, BLOCKS_W, #4
	mov		w8, #MAX_STRIDE
	cmp		BLOCKS_W, w8
	csel		BLOCKS_W, BLOCKS_W, w8, lt

	/*
	 * Set up the counter values in v0-v{MAX_STRIDE-1}.
	 *
	 * If we are encrypting less than MAX_STRIDE blocks, the tail block
	 * handling code expects the last keystream block to be in
	 * v{MAX_STRIDE-1}.  For example: if encrypting two blocks with
	 * MAX_STRIDE=5, then v3 and v4 should have the next two counter blocks.
	 */
	.if \xctr
		add		CTR, CTR, BLOCKS
	.else
		adds		IV_PART, IV_PART, BLOCKS
	.endif
	mov		v0.16b, vctr.16b
	mov		v1.16b, vctr.16b
	mov		v2.16b, vctr.16b
	mov		v3.16b, vctr.16b
ST5(	mov		v4.16b, vctr.16b		)
	.if \xctr
		/*
		 * v_i gets index CTR - (MAX_STRIDE - 1 - i), so the last
		 * register always carries index CTR.
		 */
		sub		x6, CTR, #MAX_STRIDE - 1
		sub		x7, CTR, #MAX_STRIDE - 2
		sub		x8, CTR, #MAX_STRIDE - 3
		sub		x9, CTR, #MAX_STRIDE - 4
ST5(		sub		x10, CTR, #MAX_STRIDE - 5	)
		eor		x6, x6, IV_PART
		eor		x7, x7, IV_PART
		eor		x8, x8, IV_PART
		eor		x9, x9, IV_PART
ST5(		eor		x10, x10, IV_PART		)
		mov		v0.d[0], x6
		mov		v1.d[0], x7
		mov		v2.d[0], x8
		mov		v3.d[0], x9
ST5(		mov		v4.d[0], x10			)
	.else
		bcs		0f
		.subsection	1
		/*
		 * This subsection handles carries.
		 *
		 * Conditional branching here is allowed with respect to time
		 * invariance since the branches are dependent on the IV instead
		 * of the plaintext or key.  This code is rarely executed in
		 * practice anyway.
		 */

		/* Apply carry to outgoing counter. */
0:		umov		x8, vctr.d[0]
		rev		x8, x8
		add		x8, x8, #1
		rev		x8, x8
		ins		vctr.d[0], x8

		/*
		 * Apply carry to counter blocks if needed.
		 *
		 * Since the carry flag was set, we know 0 <= IV_PART <
		 * MAX_STRIDE.  Using the value of IV_PART we can determine how
		 * many counter blocks need to be updated.
		 *
		 * Each entry of the table below is 8 bytes (bti c + mov),
		 * hence the lsl #3.  Branching IV_PART entries before label 1
		 * updates just the last IV_PART counter blocks, which are the
		 * ones whose low 64 bits wrapped past zero.
		 */
		cbz		IV_PART, 2f
		adr		x16, 1f
		sub		x16, x16, IV_PART, lsl #3
		br		x16
		bti		c
		mov		v0.d[0], vctr.d[0]
		bti		c
		mov		v1.d[0], vctr.d[0]
		bti		c
		mov		v2.d[0], vctr.d[0]
		bti		c
		mov		v3.d[0], vctr.d[0]
ST5(		bti		c				)
ST5(		mov		v4.d[0], vctr.d[0]		)
1:		b		2f
		.previous

		/*
		 * Write the big-endian low halves.  vctr gets IV_PART (the
		 * next counter), and v{MAX_STRIDE-k} gets IV_PART - k for
		 * k = 1..MAX_STRIDE-1.  v0 keeps the value copied from vctr
		 * above, which is the correct first counter when
		 * BLOCKS == MAX_STRIDE.
		 */
2:		rev		x7, IV_PART
		ins		vctr.d[1], x7
		sub		x7, IV_PART, #MAX_STRIDE - 1
		sub		x8, IV_PART, #MAX_STRIDE - 2
		sub		x9, IV_PART, #MAX_STRIDE - 3
		rev		x7, x7
		rev		x8, x8
		mov		v1.d[1], x7
		rev		x9, x9
ST5(		sub		x10, IV_PART, #MAX_STRIDE - 4	)
		mov		v2.d[1], x8
ST5(		rev		x10, x10			)
		mov		v3.d[1], x9
ST5(		mov		v4.d[1], x10			)
	.endif

	/*
	 * If there are at least MAX_STRIDE blocks left, XOR the data with
	 * keystream and store.  Otherwise jump to tail handling.
	 */
	tbnz		BYTES_W, #31, .Lctrtail\xctr
	ld1		{v5.16b-v7.16b}, [IN], #48
ST4(	bl		aes_encrypt_block4x		)
ST5(	bl		aes_encrypt_block5x		)
	eor		v0.16b, v5.16b, v0.16b
ST4(	ld1		{v5.16b}, [IN], #16		)
	eor		v1.16b, v6.16b, v1.16b
ST5(	ld1		{v5.16b-v6.16b}, [IN], #32	)
	eor		v2.16b, v7.16b, v2.16b
	eor		v3.16b, v5.16b, v3.16b
ST5(	eor		v4.16b, v6.16b, v4.16b		)
	st1		{v0.16b-v3.16b}, [OUT], #64
ST5(	st1		{v4.16b}, [OUT], #16		)
	cbz		BYTES_W, .Lctrout\xctr
	b		.LctrloopNx\xctr

.Lctrout\xctr:
	.if !\xctr
		st1		{vctr.16b}, [IV] /* return next CTR value */
	.endif
	frame_pop
	ret

.Lctrtail\xctr:
	/*
	 * Handle up to MAX_STRIDE * 16 - 1 bytes of plaintext
	 *
	 * This code expects the last keystream block to be in v{MAX_STRIDE-1}.
	 * For example: if encrypting two blocks with MAX_STRIDE=5, then v3 and
	 * v4 should have the next two counter blocks.
	 *
	 * This allows us to store the ciphertext by writing to overlapping
	 * regions of memory.  Any invalid ciphertext blocks get overwritten by
	 * correctly computed blocks.  This approach greatly simplifies the
	 * logic for storing the ciphertext.
	 */

	/*
	 * x13 = bytes in the final block (1..16).  x14 (ST5 only), x15 and
	 * x16 become 16 when the corresponding earlier block exists and 0
	 * otherwise.  A missing block is thus loaded and stored at the same
	 * address as the next one, and its output is overwritten later.
	 */
	mov		x16, #16
	ands		w7, BYTES_W, #0xf
	csel		x13, x7, x16, ne

ST5(	cmp		BYTES_W, #64 - (MAX_STRIDE << 4))
ST5(	csel		x14, x16, xzr, gt		)
	cmp		BYTES_W, #48 - (MAX_STRIDE << 4)
	csel		x15, x16, xzr, gt
	cmp		BYTES_W, #32 - (MAX_STRIDE << 4)
	csel		x16, x16, xzr, gt
	cmp		BYTES_W, #16 - (MAX_STRIDE << 4)

	adr_l		x9, .Lcts_permute_table
	add		x9, x9, x13
	ble		.Lctrtail1x\xctr

ST5(	ld1		{v5.16b}, [IN], x14		)
	ld1		{v6.16b}, [IN], x15
	ld1		{v7.16b}, [IN], x16

ST4(	bl		aes_encrypt_block4x		)
ST5(	bl		aes_encrypt_block5x		)

	ld1		{v8.16b}, [IN], x13
	ld1		{v9.16b}, [IN]
	ld1		{v10.16b}, [x9]

ST4(	eor		v6.16b, v6.16b, v0.16b		)
ST4(	eor		v7.16b, v7.16b, v1.16b		)
ST4(	tbl		v3.16b, {v3.16b}, v10.16b	)
ST4(	eor		v8.16b, v8.16b, v2.16b		)
ST4(	eor		v9.16b, v9.16b, v3.16b		)

ST5(	eor		v5.16b, v5.16b, v0.16b		)
ST5(	eor		v6.16b, v6.16b, v1.16b		)
ST5(	tbl		v4.16b, {v4.16b}, v10.16b	)
ST5(	eor		v7.16b, v7.16b, v2.16b		)
ST5(	eor		v8.16b, v8.16b, v3.16b		)
ST5(	eor		v9.16b, v9.16b, v4.16b		)

ST5(	st1		{v5.16b}, [OUT], x14		)
	st1		{v6.16b}, [OUT], x15
	st1		{v7.16b}, [OUT], x16
	add		x13, x13, OUT
	st1		{v9.16b}, [x13]		// overlapping stores
	st1		{v8.16b}, [OUT]
	b		.Lctrout\xctr

.Lctrtail1x\xctr:
	/*
	 * Handle <= 16 bytes of plaintext
	 *
	 * This code always reads and writes 16 bytes.  To avoid out of bounds
	 * accesses, XCTR and CTR modes must use a temporary buffer when
	 * encrypting/decrypting less than 16 bytes.
	 *
	 * This code is unusual in that it loads the input and stores the output
	 * relative to the end of the buffers rather than relative to the start.
	 * This causes unusual behaviour when encrypting/decrypting less than 16
	 * bytes; the end of the data is expected to be at the end of the
	 * temporary buffer rather than the start of the data being at the start
	 * of the temporary buffer.
	 */

	/*
	 * The flags still come from the last cmp before adr_l.  The ble that
	 * brought us here already relies on adr_l and add leaving them alone.
	 * eq means exactly 16 bytes remain (x7 == 0), so IN/OUT stay put.
	 * Otherwise both step back by 16 - (bytes % 16), so the 16-byte
	 * accesses end at the end of the data.
	 */
	sub		x8, x7, #16
	csel		x7, x7, x8, eq
	add		IN, IN, x7
	add		OUT, OUT, x7
	ld1		{v5.16b}, [IN]
	ld1		{v6.16b}, [OUT]
ST5(	mov		v3.16b, v4.16b			)
	encrypt_block	v3, ROUNDS_W, KEY, x8, w7
	/*
	 * With n = bytes remaining (1..16), v10 moves keystream bytes 0..n-1
	 * into the top n lanes of v3.  sshr #7 turns v11 into a mask that is
	 * 0xff exactly on those top n lanes, so bif keeps whatever OUT already
	 * held in the first 16 - n bytes.
	 */
	ld1		{v10.16b-v11.16b}, [x9]
	tbl		v3.16b, {v3.16b}, v10.16b
	sshr		v11.16b, v11.16b, #7
	eor		v5.16b, v5.16b, v3.16b
	bif		v5.16b, v6.16b, v11.16b
	st1		{v5.16b}, [OUT]
	b		.Lctrout\xctr

	// Arguments
	.unreq OUT
	.unreq IN
	.unreq KEY
	.unreq ROUNDS_W
	.unreq BYTES_W
	.unreq IV
	.unreq BYTE_CTR_W	// XCTR only
	// Intermediate values
	.unreq CTR_W		// XCTR only
	.unreq CTR		// XCTR only
	.unreq IV_PART
	.unreq BLOCKS
	.unreq BLOCKS_W
.endm

	/*
	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int bytes, u8 ctr[])
	 *
	 * ctr[] is a 128-bit big-endian counter block.  On return it holds the
	 * next unused counter value, so a follow-on call continues the stream.
	 *
	 * The input and output buffers must always be at least 16 bytes even if
	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
	 * accesses will occur.  The data to be encrypted/decrypted is expected
	 * to be at the end of this 16-byte temporary buffer rather than the
	 * start.
	 */

AES_FUNC_START(aes_ctr_encrypt)
	ctr_encrypt 0
AES_FUNC_END(aes_ctr_encrypt)

	/*
	 * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int bytes, u8 const iv[], int byte_ctr)
	 *
	 * iv[] is only read, never written back.  byte_ctr (presumably the
	 * byte offset of in[0] within the whole message) selects the starting
	 * block index: the first block processed uses byte_ctr / 16 + 1,
	 * XORed into the first 8 bytes of iv as a little-endian 64-bit value.
	 *
	 * The input and output buffers must always be at least 16 bytes even if
	 * encrypting/decrypting less than 16 bytes.  Otherwise out of bounds
	 * accesses will occur.  The data to be encrypted/decrypted is expected
	 * to be at the end of this 16-byte temporary buffer rather than the
	 * start.
	 */

AES_FUNC_START(aes_xctr_encrypt)
	ctr_encrypt 1
AES_FUNC_END(aes_xctr_encrypt)


	/*
	 * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
	 *		   int bytes, u8 const rk2[], u8 iv[], int first)
	 * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
	 *		   int bytes, u8 const rk2[], u8 iv[], int first)
	 */

	/*
	 * next_tweak: out = in * x in GF(2^128), the XTS tweak update.
	 *
	 * The 128-bit value is handled as two 64-bit lanes, both doubled with
	 * add.  The bit shifted out of the low lane is carried into bit 0 of
	 * the high lane.  The bit shifted out of the high lane is folded back
	 * into the low lane as the reduction constant 0x87.  Mechanically,
	 * sshr #63 turns each lane's top bit into an all-ones/zero mask,
	 * xtsmask selects 1 (low lane) or 0x87 (high lane), and ext #8 swaps
	 * the lanes so each correction lands in the other lane.  tmp is
	 * clobbered; in is fully read before out is written, so out == in is
	 * fine.
	 */
	.macro		next_tweak, out, in, tmp
	sshr		\tmp\().2d,  \in\().2d,   #63
	and		\tmp\().16b, \tmp\().16b, xtsmask.16b
	add		\out\().2d,  \in\().2d,   \in\().2d
	ext		\tmp\().16b, \tmp\().16b, \tmp\().16b, #8
	eor		\out\().16b, \out\().16b, \tmp\().16b
	.endm

	/*
	 * xts_load_mask: set xtsmask (a register alias, presumably defined by
	 * the including file) to d[0] = 0x1, d[1] = 0x87.  uzp1 takes the even
	 * 32-bit elements of {1, 1, 0, 0} and {0x87, 0x87, 0, 0}.  tmp is
	 * clobbered.
	 */
	.macro		xts_load_mask, tmp
	movi		xtsmask.2s, #0x1
	movi		\tmp\().2s, #0x87
	uzp1		xtsmask.4s, xtsmask.4s, \tmp\().4s
	.endm

AES_FUNC_START(aes_xts_encrypt)
	/*
	 * x0 = out, x1 = in, x2 = rk1, w3 = rounds, w4 = bytes, x5 = rk2,
	 * x6 = iv, w7 = first.
	 *
	 * v4 holds the current tweak and is written back to iv on return;
	 * v8 is scratch for next_tweak.  When first is set, the initial tweak
	 * is iv encrypted with rk2, unless xts_cts_skip_tw (defined elsewhere)
	 * branches past it.  w7 is reused as scratch after that test.  The
	 * 4x loop always uses the 4-way helper, independent of MAX_STRIDE.  A
	 * trailing partial block is handled with ciphertext stealing at
	 * .Lxtsenccts.
	 */
	frame_push	0

	ld1		{v4.16b}, [x6]
	xts_load_mask	v8
	cbz		w7, .Lxtsencnotfirst

	enc_prepare	w3, x5, x8
	xts_cts_skip_tw	w7, .LxtsencNx
	encrypt_block	v4, w3, x5, x8, w7		/* first tweak */
	enc_switch_key	w3, x2, x8
	b		.LxtsencNx

.Lxtsencnotfirst:
	enc_prepare	w3, x2, x8
.LxtsencloopNx:
	next_tweak	v4, v4, v8
.LxtsencNx:
	subs		w4, w4, #64
	bmi		.Lxtsenc1x
	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 pt blocks */
	next_tweak	v5, v4, v8
	eor		v0.16b, v0.16b, v4.16b
	next_tweak	v6, v5, v8
	eor		v1.16b, v1.16b, v5.16b
	eor		v2.16b, v2.16b, v6.16b
	next_tweak	v7, v6, v8
	eor		v3.16b, v3.16b, v7.16b
	bl		aes_encrypt_block4x
	eor		v3.16b, v3.16b, v7.16b
	eor		v0.16b, v0.16b, v4.16b
	eor		v1.16b, v1.16b, v5.16b
	eor		v2.16b, v2.16b, v6.16b
	st1		{v0.16b-v3.16b}, [x0], #64
	mov		v4.16b, v7.16b			/* last tweak used */
	cbz		w4, .Lxtsencret
	xts_reload_mask	v8
	b		.LxtsencloopNx
.Lxtsenc1x:
	adds		w4, w4, #64
	/*
	 * w4 can only be 0 here if bytes was 0 on entry.  In that case
	 * nothing was ever loaded into v0, so skip the final block store
	 * (which would write 16 bytes of stale register contents to out[])
	 * and only write back the tweak.  aes_xts_decrypt behaves the same
	 * way in this situation.
	 */
	beq		.Lxtsencret
	subs		w4, w4, #16
	bmi		.LxtsencctsNx
.Lxtsencloop:
	ld1		{v0.16b}, [x1], #16
.Lxtsencctsout:
	eor		v0.16b, v0.16b, v4.16b
	encrypt_block	v0, w3, x2, x8, w7
	eor		v0.16b, v0.16b, v4.16b
	cbz		w4, .Lxtsencout
	subs		w4, w4, #16
	next_tweak	v4, v4, v8
	bmi		.Lxtsenccts			/* next_tweak leaves NZCV alone */
	st1		{v0.16b}, [x0], #16
	b		.Lxtsencloop
.Lxtsencout:
	st1		{v0.16b}, [x0]
.Lxtsencret:
	st1		{v4.16b}, [x6]
	frame_pop
	ret

.LxtsencctsNx:
	/*
	 * Partial tail directly after a 4-block batch.  Steal from the last
	 * block of that batch (v3), which has already been stored, so move
	 * x0 back over it; it is rewritten below.
	 */
	mov		v0.16b, v3.16b
	sub		x0, x0, #16
.Lxtsenccts:
	/*
	 * v0 = ciphertext of the last full block, to be rewritten at x0.
	 * w4 = (bytes in the partial block) - 16, in -15..-1.
	 */
	adr_l		x8, .Lcts_permute_table

	add		x1, x1, w4, sxtw	/* rewind input pointer */
	add		w4, w4, #16		/* # bytes in final block */
	add		x9, x8, #32
	add		x8, x8, x4
	sub		x9, x9, x4
	add		x4, x0, x4		/* output address of final block */

	ld1		{v1.16b}, [x1]		/* load final block */
	ld1		{v2.16b}, [x8]
	ld1		{v3.16b}, [x9]

	tbl		v2.16b, {v0.16b}, v2.16b
	tbx		v0.16b, {v1.16b}, v3.16b
	st1		{v2.16b}, [x4]			/* overlapping stores */
	mov		w4, wzr			/* nothing follows the stolen block */
	b		.Lxtsencctsout
AES_FUNC_END(aes_xts_encrypt)

AES_FUNC_START(aes_xts_decrypt)
	/*
	 * x0 = out, x1 = in, x2 = rk1, w3 = rounds, w4 = bytes, x5 = rk2,
	 * x6 = iv, w7 = first.
	 *
	 * v4 holds the current tweak and is written back to iv on return;
	 * v8 is scratch for next_tweak.  When first is set, the initial tweak
	 * is iv encrypted with rk2 (xts_cts_skip_tw, defined elsewhere, may
	 * branch past that depending on w7).
	 *
	 * Ciphertext stealing, decrypt direction: the last full ciphertext
	 * block is decrypted with the tweak that follows the current one (v5
	 * at .Lxtsdeccts), and the reassembled block with the current tweak.
	 * That last full block is therefore held back from the main loops: if
	 * bytes is not a multiple of 16, 16 is subtracted up front, and
	 * .Lxtsdeccts handles the final 16 + (bytes % 16) bytes.
	 */
	frame_push	0

	/* subtract 16 bytes if we are doing CTS */
	sub		w8, w4, #0x10
	tst		w4, #0xf
	csel		w4, w4, w8, eq

	ld1		{v4.16b}, [x6]
	xts_load_mask	v8
	xts_cts_skip_tw	w7, .Lxtsdecskiptw
	cbz		w7, .Lxtsdecnotfirst

	enc_prepare	w3, x5, x8
	encrypt_block	v4, w3, x5, x8, w7		/* first tweak */
.Lxtsdecskiptw:
	dec_prepare	w3, x2, x8
	b		.LxtsdecNx

.Lxtsdecnotfirst:
	dec_prepare	w3, x2, x8
.LxtsdecloopNx:
	next_tweak	v4, v4, v8
.LxtsdecNx:
	subs		w4, w4, #64
	bmi		.Lxtsdec1x
	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 ct blocks */
	next_tweak	v5, v4, v8
	eor		v0.16b, v0.16b, v4.16b
	next_tweak	v6, v5, v8
	eor		v1.16b, v1.16b, v5.16b
	eor		v2.16b, v2.16b, v6.16b
	next_tweak	v7, v6, v8
	eor		v3.16b, v3.16b, v7.16b
	bl		aes_decrypt_block4x
	eor		v3.16b, v3.16b, v7.16b
	eor		v0.16b, v0.16b, v4.16b
	eor		v1.16b, v1.16b, v5.16b
	eor		v2.16b, v2.16b, v6.16b
	st1		{v0.16b-v3.16b}, [x0], #64
	mov		v4.16b, v7.16b			/* last tweak used */
	cbz		w4, .Lxtsdecout
	xts_reload_mask	v8
	b		.LxtsdecloopNx
.Lxtsdec1x:
	adds		w4, w4, #64
	beq		.Lxtsdecout
	subs		w4, w4, #16
.Lxtsdecloop:
	ld1		{v0.16b}, [x1], #16
	/*
	 * N comes from the most recent subs w4.  Neither ld1 nor the vector
	 * ops in next_tweak set the flags.
	 */
	bmi		.Lxtsdeccts
.Lxtsdecctsout:
	eor		v0.16b, v0.16b, v4.16b
	decrypt_block	v0, w3, x2, x8, w7
	eor		v0.16b, v0.16b, v4.16b
	st1		{v0.16b}, [x0], #16
	cbz		w4, .Lxtsdecout
	subs		w4, w4, #16
	next_tweak	v4, v4, v8
	b		.Lxtsdecloop
.Lxtsdecout:
	st1		{v4.16b}, [x6]
	frame_pop
	ret

.Lxtsdeccts:
	/*
	 * v0 = the held-back last full ciphertext block (x1 is past it).
	 * w4 = (bytes in the partial block) - 16, in -15..-1.
	 */
	adr_l		x8, .Lcts_permute_table

	add		x1, x1, w4, sxtw	/* rewind input pointer */
	add		w4, w4, #16		/* # bytes in final block */
	add		x9, x8, #32
	add		x8, x8, x4
	sub		x9, x9, x4
	add		x4, x0, x4		/* output address of final block */

	next_tweak	v5, v4, v8

	ld1		{v1.16b}, [x1]		/* load final block */
	ld1		{v2.16b}, [x8]
	ld1		{v3.16b}, [x9]

	eor		v0.16b, v0.16b, v5.16b
	decrypt_block	v0, w3, x2, x8, w7
	eor		v0.16b, v0.16b, v5.16b

	tbl		v2.16b, {v0.16b}, v2.16b
	tbx		v0.16b, {v1.16b}, v3.16b

	st1		{v2.16b}, [x4]			/* overlapping stores */
	mov		w4, wzr			/* nothing follows the rebuilt block */
	b		.Lxtsdecctsout
AES_FUNC_END(aes_xts_decrypt)

	/*
	 * aes_mac_update(u8 const in[], u32 const rk[], int rounds,
	 *		  int blocks, u8 dg[], int enc_before, int enc_after)
	 *
	 * CBC-MAC style digest update.  Registers: x0 = in, x1 = rk,
	 * w2 = rounds, w3 = blocks, x4 = dg, w5 = enc_before, w6 = enc_after.
	 *
	 * If enc_before is nonzero, dg is encrypted once before any input is
	 * absorbed.  Each input block is then XORed into dg and dg encrypted.
	 * The exception is the encryption after the very last block, which is
	 * skipped when enc_after is zero.  dg is written back on exit.
	 *
	 * Returns (w0) the number of blocks not yet processed.  This is 0 once
	 * all input is consumed, and nonzero if the 4x loop left early through
	 * cond_yield (a macro defined elsewhere, presumably taken when the task
	 * should give up the CPU).
	 */
AES_FUNC_START(aes_mac_update)
	ld1		{v0.16b}, [x4]			/* get dg */
	enc_prepare	w2, x1, x7
	cbz		w5, .Lmacloop4x

	encrypt_block	v0, w2, x1, x7, w8

.Lmacloop4x:
	subs		w3, w3, #4
	bmi		.Lmac1x
	ld1		{v1.16b-v4.16b}, [x0], #64	/* get next 4 pt blocks */
	eor		v0.16b, v0.16b, v1.16b		/* ..and xor with dg */
	encrypt_block	v0, w2, x1, x7, w8
	eor		v0.16b, v0.16b, v2.16b
	encrypt_block	v0, w2, x1, x7, w8
	eor		v0.16b, v0.16b, v3.16b
	encrypt_block	v0, w2, x1, x7, w8
	eor		v0.16b, v0.16b, v4.16b
	/*
	 * x5 = (no blocks left) ? enc_after : ~0.  The encryption below is
	 * therefore skipped only after the final block when enc_after == 0.
	 */
	cmp		w3, wzr
	csinv		x5, x6, xzr, eq
	cbz		w5, .Lmacout
	encrypt_block	v0, w2, x1, x7, w8
	st1		{v0.16b}, [x4]			/* return dg */
	cond_yield	.Lmacout, x7, x8
	b		.Lmacloop4x
.Lmac1x:
	add		w3, w3, #4			/* 0-3 blocks left */
.Lmacloop:
	cbz		w3, .Lmacout
	ld1		{v1.16b}, [x0], #16		/* get next pt block */
	eor		v0.16b, v0.16b, v1.16b		/* ..and xor with dg */

	/* Same rule as above: skip the last encryption if enc_after == 0. */
	subs		w3, w3, #1
	csinv		x5, x6, xzr, eq
	cbz		w5, .Lmacout

.Lmacenc:
	encrypt_block	v0, w2, x1, x7, w8
	b		.Lmacloop

.Lmacout:
	st1		{v0.16b}, [x4]			/* return dg */
	mov		w0, w3
	ret
AES_FUNC_END(aes_mac_update)