/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * Core of the accelerated CRC algorithm.
 * In your file, define the constants and CRC_FUNCTION_NAME
 * Then include this file.
 *
 * Calculate the checksum of data that is 16 byte aligned and a multiple of
 * 16 bytes.
 *
 * The first step is to reduce it to 1024 bits. We do this in 8 parallel
 * chunks in order to mask the latency of the vpmsum instructions. If we
 * have more than 32 kB of data to checksum we repeat this step multiple
 * times, passing in the previous 1024 bits.
 *
 * The next step is to reduce the 1024 bits to 64 bits. This step adds
 * 32 bits of 0s to the end - this matches what a CRC does. We just
 * calculate constants that land the data in this 32 bits.
 *
 * We then use fixed point Barrett reduction to compute a mod n over GF(2)
 * for n = CRC using POWER8 instructions. We use x = 32.
 *
 * https://en.wikipedia.org/wiki/Barrett_reduction
 *
 * Copyright (C) 2015 Anton Blanchard <anton@au.ibm.com>, IBM
*/

#include <asm/ppc_asm.h>
#include <asm/ppc-opcode.h>

/*
 * Upper bound, in bytes, on the amount of data folded by one pass of the
 * main reduction loop. The starting offset into the constants table is
 * also computed relative to this value (MAX_SIZE / 8).
 */
#define MAX_SIZE	32768

	.text

/*
 * BYTESWAP_DATA is defined for big endian builds with REFLECT and for
 * little endian builds without REFLECT. When it is defined, every 16 byte
 * load of input data is followed by a vperm through the byteswap register
 * (see VPERM below); otherwise VPERM expands to nothing.
 */
#if defined(__BIG_ENDIAN__) && defined(REFLECT)
#define BYTESWAP_DATA
#elif defined(__LITTLE_ENDIAN__) && !defined(REFLECT)
#define BYTESWAP_DATA
#else
#undef BYTESWAP_DATA
#endif

/*
 * GPRs that hold the constant byte offsets 16..112. They are loaded once in
 * the function prologue and used as the index register of lvx/stvx.
 */
#define off16		r25
#define off32		r26
#define off48		r27
#define off64		r28
#define off80		r29
#define off96		r30
#define off112		r31

/* Vector registers holding the constants currently loaded from a table */
#define const1		v24
#define const2		v25

/*
 * byteswap:   permute control vector loaded from .byteswap_constant
 *             (only loaded when BYTESWAP_DATA is defined)
 * mask_32bit: ones in the rightmost 32 bits, zero elsewhere
 * mask_64bit: ones in the rightmost 64 bits, zero elsewhere
 * zeroes:     all zero
 */
#define byteswap	v26
#define	mask_32bit	v27
#define	mask_64bit	v28
#define zeroes		v29

/* Emit a vperm only when the input data has to be byte swapped */
#ifdef BYTESWAP_DATA
#define VPERM(A, B, C, D) vperm	A, B, C, D
#else
#define VPERM(A, B, C, D)
#endif

/*
 * unsigned int CRC_FUNCTION_NAME(unsigned int crc, void *p, unsigned long len)
 *
 * In:
 *   r3 = crc - initial value; it is xored into the first 16 bytes of data
 *   r4 = p   - input data (16 byte aligned, see the top of this file)
 *   r5 = len - length in bytes (a multiple of 16, see the top of this file)
 * Out:
 *   r3 = checksum; when len == 0 the incoming crc is returned unchanged
 *
 * Three paths are taken depending on len:
 *   len == 0    .Lzero  - return crc
 *   len <  256  .Lshort - one vpmsumw per 16 byte chunk, then Barrett
 *                         reduction
 *   otherwise           - fold the data in passes of at most MAX_SIZE bytes
 *                         using v0-v7 as eight parallel accumulators, reduce
 *                         the resulting 1024 bits together with the tail,
 *                         then Barrett reduction
 *
 * r25-r31 and v20-v29 are saved and restored at negative offsets from r1.
 * r1 itself is never modified, so no stack frame is created.
 */
FUNC_START(CRC_FUNCTION_NAME)
	/* Save the GPRs that are used as off16..off112 */
	std	r31,-8(r1)
	std	r30,-16(r1)
	std	r29,-24(r1)
	std	r28,-32(r1)
	std	r27,-40(r1)
	std	r26,-48(r1)
	std	r25,-56(r1)

	li	off16,16
	li	off32,32
	li	off48,48
	li	off64,64
	li	off80,80
	li	off96,96
	li	off112,112
	/* r0 = 0: v16-v23 do not hold data from a previous pass yet */
	li	r0,0

	/*
	 * Enough room for saving 10 non volatile VMX registers.
	 * r6 addresses the slots for v20-v27, r7 the slots for v28-v29.
	 */
	subi	r6,r1,56+10*16
	subi	r7,r1,56+2*16

	stvx	v20,0,r6
	stvx	v21,off16,r6
	stvx	v22,off32,r6
	stvx	v23,off48,r6
	stvx	v24,off64,r6
	stvx	v25,off80,r6
	stvx	v26,off96,r6
	stvx	v27,off112,r6
	stvx	v28,0,r7
	stvx	v29,off16,r7

	/* Keep the incoming crc; .Lzero returns it when len == 0 */
	mr	r10,r3

	/* zeroes = 0, v0 = all ones (used only to build the two masks) */
	vxor	zeroes,zeroes,zeroes
	vspltisw v0,-1

	/* Shift ones in from the right: 4 bytes and 8 bytes respectively */
	vsldoi	mask_32bit,zeroes,v0,4
	vsldoi	mask_64bit,zeroes,v0,8

	/* Get the initial value into v8 */
	vxor	v8,v8,v8
	MTVRD(v8, R3)
#ifdef REFLECT
	vsldoi	v8,zeroes,v8,8	/* shift into bottom 32 bits */
#else
	vsldoi	v8,v8,zeroes,4	/* shift into top 32 bits */
#endif

#ifdef BYTESWAP_DATA
	/* Load the permute control used by every VPERM() below */
	LOAD_REG_ADDR(r3, .byteswap_constant)
	lvx	byteswap,0,r3
	/*
	 * NOTE(review): r3 is reloaded (LOAD_REG_ADDR or mr) on every path
	 * before it is read again, so this increment looks unused - confirm.
	 */
	addi	r3,r3,16
#endif

	/* Fewer than 256 bytes (including zero) are handled at .Lshort */
	cmpdi	r5,256
	blt	.Lshort

	/* r6 = len with the low 7 bits cleared: bytes for the 128 byte loop */
	rldicr	r6,r5,0,56

	/*
	 * Checksum in blocks of MAX_SIZE.
	 * r7 = min(r6, MAX_SIZE) is the size of this pass, r9 = MAX_SIZE,
	 * and r6 is reduced to the bytes that remain after this pass.
	 */
1:	lis	r7,MAX_SIZE@h
	ori	r7,r7,MAX_SIZE@l
	mr	r9,r7
	cmpd	r6,r7
	bgt	2f
	mr	r7,r6
2:	subf	r6,r7,r6

	/* our main loop does 128 bytes at a time: r7 = number of groups */
	srdi	r7,r7,7

	/*
	 * Work out the offset into the constants table to start at. Each
	 * constant is 16 bytes, and it is used against 128 bytes of input
	 * data - 128 / 16 = 8
	 *
	 * r8 = MAX_SIZE / 8 - groups * 16
	 */
	sldi	r8,r7,4
	srdi	r9,r9,3
	subf	r8,r8,r9

	/* We reduce our final 128 bytes in a separate step */
	addi	r7,r7,-1
	mtctr	r7

	LOAD_REG_ADDR(r3, .constants)

	/* Find the start of our constants */
	add	r3,r3,r8

	/* zero v0-v7 which will contain our checksums */
	vxor	v0,v0,v0
	vxor	v1,v1,v1
	vxor	v2,v2,v2
	vxor	v3,v3,v3
	vxor	v4,v4,v4
	vxor	v5,v5,v5
	vxor	v6,v6,v6
	vxor	v7,v7,v7

	lvx	const1,0,r3

	/*
	 * If we are looping back to consume more data we use the values
	 * already in v16-v23. r0 is set to 1 at the end of every pass.
	 */
	cmpdi	r0,1
	beq	2f

	/* First warm up pass: load the first 128 bytes into v16-v23 */
	lvx	v16,0,r4
	lvx	v17,off16,r4
	VPERM(v16,v16,v16,byteswap)
	VPERM(v17,v17,v17,byteswap)
	lvx	v18,off32,r4
	lvx	v19,off48,r4
	VPERM(v18,v18,v18,byteswap)
	VPERM(v19,v19,v19,byteswap)
	lvx	v20,off64,r4
	lvx	v21,off80,r4
	VPERM(v20,v20,v20,byteswap)
	VPERM(v21,v21,v21,byteswap)
	lvx	v22,off96,r4
	lvx	v23,off112,r4
	VPERM(v22,v22,v22,byteswap)
	VPERM(v23,v23,v23,byteswap)
	addi	r4,r4,8*16

	/* xor in initial value */
	vxor	v16,v16,v8

	/*
	 * ctr reaches zero here when this pass consists of just two 128 byte
	 * groups; the pipelined passes below are skipped in that case.
	 */
2:	bdz	.Lfirst_warm_up_done

	addi	r3,r3,16
	lvx	const2,0,r3

	/*
	 * Second warm up pass: multiply the group held in v16-v23 by const1
	 * into v8-v15 while loading the next group into v16-v23.
	 *
	 * NOTE(review): "ori r2,r2,0" leaves r2 unchanged; presumably it is
	 * an instruction scheduling hint that separates the groups of
	 * instructions - confirm against the POWER8 documentation.
	 */
	VPMSUMD(v8,v16,const1)
	lvx	v16,0,r4
	VPERM(v16,v16,v16,byteswap)
	ori	r2,r2,0

	VPMSUMD(v9,v17,const1)
	lvx	v17,off16,r4
	VPERM(v17,v17,v17,byteswap)
	ori	r2,r2,0

	VPMSUMD(v10,v18,const1)
	lvx	v18,off32,r4
	VPERM(v18,v18,v18,byteswap)
	ori	r2,r2,0

	VPMSUMD(v11,v19,const1)
	lvx	v19,off48,r4
	VPERM(v19,v19,v19,byteswap)
	ori	r2,r2,0

	VPMSUMD(v12,v20,const1)
	lvx	v20,off64,r4
	VPERM(v20,v20,v20,byteswap)
	ori	r2,r2,0

	VPMSUMD(v13,v21,const1)
	lvx	v21,off80,r4
	VPERM(v21,v21,v21,byteswap)
	ori	r2,r2,0

	VPMSUMD(v14,v22,const1)
	lvx	v22,off96,r4
	VPERM(v22,v22,v22,byteswap)
	ori	r2,r2,0

	VPMSUMD(v15,v23,const1)
	lvx	v23,off112,r4
	VPERM(v23,v23,v23,byteswap)

	addi	r4,r4,8*16

	bdz	.Lfirst_cool_down

	/*
	 * main loop. We modulo schedule it such that it takes three iterations
	 * to complete - first iteration load, second iteration vpmsum, third
	 * iteration xor.
	 *
	 * Per iteration: v0-v7 accumulate the previous products (v8-v15),
	 * v8-v15 receive the products of the previously loaded data, and
	 * v16-v23 are reloaded with the next 128 bytes. The first four
	 * products use const2 and the last four use const1; the constant for
	 * the following iteration is loaded into const2 in between.
	 */
	.balign	16
4:	lvx	const1,0,r3
	addi	r3,r3,16
	ori	r2,r2,0

	vxor	v0,v0,v8
	VPMSUMD(v8,v16,const2)
	lvx	v16,0,r4
	VPERM(v16,v16,v16,byteswap)
	ori	r2,r2,0

	vxor	v1,v1,v9
	VPMSUMD(v9,v17,const2)
	lvx	v17,off16,r4
	VPERM(v17,v17,v17,byteswap)
	ori	r2,r2,0

	vxor	v2,v2,v10
	VPMSUMD(v10,v18,const2)
	lvx	v18,off32,r4
	VPERM(v18,v18,v18,byteswap)
	ori	r2,r2,0

	vxor	v3,v3,v11
	VPMSUMD(v11,v19,const2)
	lvx	v19,off48,r4
	VPERM(v19,v19,v19,byteswap)
	lvx	const2,0,r3
	ori	r2,r2,0

	vxor	v4,v4,v12
	VPMSUMD(v12,v20,const1)
	lvx	v20,off64,r4
	VPERM(v20,v20,v20,byteswap)
	ori	r2,r2,0

	vxor	v5,v5,v13
	VPMSUMD(v13,v21,const1)
	lvx	v21,off80,r4
	VPERM(v21,v21,v21,byteswap)
	ori	r2,r2,0

	vxor	v6,v6,v14
	VPMSUMD(v14,v22,const1)
	lvx	v22,off96,r4
	VPERM(v22,v22,v22,byteswap)
	ori	r2,r2,0

	vxor	v7,v7,v15
	VPMSUMD(v15,v23,const1)
	lvx	v23,off112,r4
	VPERM(v23,v23,v23,byteswap)

	addi	r4,r4,8*16

	bdnz	4b

.Lfirst_cool_down:
	/*
	 * First cool down pass: accumulate the outstanding products and
	 * multiply the last loaded group; no further data is loaded.
	 */
	lvx	const1,0,r3
	addi	r3,r3,16

	vxor	v0,v0,v8
	VPMSUMD(v8,v16,const1)
	ori	r2,r2,0

	vxor	v1,v1,v9
	VPMSUMD(v9,v17,const1)
	ori	r2,r2,0

	vxor	v2,v2,v10
	VPMSUMD(v10,v18,const1)
	ori	r2,r2,0

	vxor	v3,v3,v11
	VPMSUMD(v11,v19,const1)
	ori	r2,r2,0

	vxor	v4,v4,v12
	VPMSUMD(v12,v20,const1)
	ori	r2,r2,0

	vxor	v5,v5,v13
	VPMSUMD(v13,v21,const1)
	ori	r2,r2,0

	vxor	v6,v6,v14
	VPMSUMD(v14,v22,const1)
	ori	r2,r2,0

	vxor	v7,v7,v15
	VPMSUMD(v15,v23,const1)
	ori	r2,r2,0

.Lsecond_cool_down:
	/* Second cool down pass: accumulate the final products */
	vxor	v0,v0,v8
	vxor	v1,v1,v9
	vxor	v2,v2,v10
	vxor	v3,v3,v11
	vxor	v4,v4,v12
	vxor	v5,v5,v13
	vxor	v6,v6,v14
	vxor	v7,v7,v15

#ifdef REFLECT
	/*
	 * vpmsumd produces a 96 bit result in the least significant bits
	 * of the register. Since we are bit reflected we have to shift it
	 * left 32 bits so it occupies the least significant bits in the
	 * bit reflected domain.
	 */
	vsldoi	v0,v0,zeroes,4
	vsldoi	v1,v1,zeroes,4
	vsldoi	v2,v2,zeroes,4
	vsldoi	v3,v3,zeroes,4
	vsldoi	v4,v4,zeroes,4
	vsldoi	v5,v5,zeroes,4
	vsldoi	v6,v6,zeroes,4
	vsldoi	v7,v7,zeroes,4
#endif

	/*
	 * xor with last 1024 bits: load the final 128 bytes of this pass and
	 * combine them with the accumulators. The result is left in v16-v23.
	 */
	lvx	v8,0,r4
	lvx	v9,off16,r4
	VPERM(v8,v8,v8,byteswap)
	VPERM(v9,v9,v9,byteswap)
	lvx	v10,off32,r4
	lvx	v11,off48,r4
	VPERM(v10,v10,v10,byteswap)
	VPERM(v11,v11,v11,byteswap)
	lvx	v12,off64,r4
	lvx	v13,off80,r4
	VPERM(v12,v12,v12,byteswap)
	VPERM(v13,v13,v13,byteswap)
	lvx	v14,off96,r4
	lvx	v15,off112,r4
	VPERM(v14,v14,v14,byteswap)
	VPERM(v15,v15,v15,byteswap)

	addi	r4,r4,8*16

	vxor	v16,v0,v8
	vxor	v17,v1,v9
	vxor	v18,v2,v10
	vxor	v19,v3,v11
	vxor	v20,v4,v12
	vxor	v21,v5,v13
	vxor	v22,v6,v14
	vxor	v23,v7,v15

	/*
	 * r0 = 1: v16-v23 now hold data, so a following pass skips its first
	 * warm up load. The compare is made on r6 before it is adjusted; the
	 * extra 128 accounts for the group already held in v16-v23, which the
	 * next pass processes without reloading it.
	 */
	li	r0,1
	cmpdi	r6,0
	addi	r6,r6,128
	bne	1b

	/* Work out how many bytes we have left (r5 = len modulo 128) */
	andi.	r5,r5,127

	/* Calculate where in the constant table we need to start */
	subfic	r6,r5,128
	add	r3,r3,r6

	/* How many 16 byte chunks are in the tail */
	srdi	r7,r5,4
	mtctr	r7

	/*
	 * Reduce the previously calculated 1024 bits to 64 bits, shifting
	 * 32 bits to include the trailing 32 bits of zeros
	 */
	lvx	v0,0,r3
	lvx	v1,off16,r3
	lvx	v2,off32,r3
	lvx	v3,off48,r3
	lvx	v4,off64,r3
	lvx	v5,off80,r3
	lvx	v6,off96,r3
	lvx	v7,off112,r3
	addi	r3,r3,8*16

	VPMSUMW(v0,v16,v0)
	VPMSUMW(v1,v17,v1)
	VPMSUMW(v2,v18,v2)
	VPMSUMW(v3,v19,v3)
	VPMSUMW(v4,v20,v4)
	VPMSUMW(v5,v21,v5)
	VPMSUMW(v6,v22,v6)
	VPMSUMW(v7,v23,v7)

	/*
	 * Now reduce the tail (0 - 112 bytes). Each 16 byte chunk is
	 * multiplied by the constant at the same offset and xored into v0;
	 * ctr counts the chunks that are left.
	 */
	cmpdi	r7,0
	beq	1f

	lvx	v16,0,r4
	lvx	v17,0,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16
	bdz	1f

	lvx	v16,off16,r4
	lvx	v17,off16,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16
	bdz	1f

	lvx	v16,off32,r4
	lvx	v17,off32,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16
	bdz	1f

	lvx	v16,off48,r4
	lvx	v17,off48,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16
	bdz	1f

	lvx	v16,off64,r4
	lvx	v17,off64,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16
	bdz	1f

	lvx	v16,off80,r4
	lvx	v17,off80,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16
	bdz	1f

	lvx	v16,off96,r4
	lvx	v17,off96,r3
	VPERM(v16,v16,v16,byteswap)
	VPMSUMW(v16,v16,v17)
	vxor	v0,v0,v16

	/* Now xor all the parallel chunks together */
1:	vxor	v0,v0,v1
	vxor	v2,v2,v3
	vxor	v4,v4,v5
	vxor	v6,v6,v7

	vxor	v0,v0,v2
	vxor	v4,v4,v6

	vxor	v0,v0,v4

	/* Shared by the main path and .Lshort; the value to reduce is in v0 */
.Lbarrett_reduction:
	/* Barrett constants */
	LOAD_REG_ADDR(r3, .barrett_constants)

	lvx	const1,0,r3
	lvx	const2,off16,r3

	vsldoi	v1,v0,v0,8
	vxor	v0,v0,v1		/* xor two 64 bit results together */

#ifdef REFLECT
	/* shift left one bit */
	vspltisb v1,1
	vsl	v0,v0,v1
#endif

	vand	v0,v0,mask_64bit
#ifndef REFLECT
	/*
	 * Now for the Barrett reduction algorithm. The idea is to calculate q,
	 * the multiple of our polynomial that we need to subtract. By
	 * doing the computation 2x bits higher (ie 64 bits) and shifting the
	 * result back down 2x bits, we round down to the nearest multiple.
	 */
	VPMSUMD(v1,v0,const1)	/* ma */
	vsldoi	v1,zeroes,v1,8	/* q = floor(ma/(2^64)) */
	VPMSUMD(v1,v1,const2)	/* qn */
	vxor	v0,v0,v1	/* a - qn, subtraction is xor in GF(2) */

	/*
	 * Get the result into r3. We need to shift it left 8 bytes:
	 * V0 [ 0 1 2 X ]
	 * V0 [ 0 X 2 3 ]
	 */
	vsldoi	v0,v0,zeroes,8	/* shift result into top 64 bits */
#else
	/*
	 * The reflected version of Barrett reduction. Instead of bit
	 * reflecting our data (which is expensive to do), we bit reflect our
	 * constants and our algorithm, which means the intermediate data in
	 * our vector registers goes from 0-63 instead of 63-0. We can reflect
	 * the algorithm because we don't carry in mod 2 arithmetic.
	 */
	vand	v1,v0,mask_32bit	/* bottom 32 bits of a */
	VPMSUMD(v1,v1,const1)		/* ma */
	vand	v1,v1,mask_32bit	/* bottom 32bits of ma */
	VPMSUMD(v1,v1,const2)		/* qn */
	vxor	v0,v0,v1		/* a - qn, subtraction is xor in GF(2) */

	/*
	 * Since we are bit reflected, the result (ie the low 32 bits) is in
	 * the high 32 bits. We just need to shift it left 4 bytes
	 * V0 [ 0 1 X 3 ]
	 * V0 [ 0 X 2 3 ]
	 */
	vsldoi	v0,v0,zeroes,4		/* shift result into top 64 bits of v0 */
#endif

	/* Get it into r3 */
	MFVRD(R3, v0)

	/* Common exit: restore v20-v29 and r25-r31 from where they were saved */
.Lout:
	subi	r6,r1,56+10*16
	subi	r7,r1,56+2*16

	lvx	v20,0,r6
	lvx	v21,off16,r6
	lvx	v22,off32,r6
	lvx	v23,off48,r6
	lvx	v24,off64,r6
	lvx	v25,off80,r6
	lvx	v26,off96,r6
	lvx	v27,off112,r6
	lvx	v28,0,r7
	lvx	v29,off16,r7

	ld	r31,-8(r1)
	ld	r30,-16(r1)
	ld	r29,-24(r1)
	ld	r28,-32(r1)
	ld	r27,-40(r1)
	ld	r26,-48(r1)
	ld	r25,-56(r1)

	blr

	/*
	 * Reached from the top of a pass when only one group remains to be
	 * multiplied: form the products of v16-v23 without loading further
	 * data and join the common code at the second cool down pass.
	 */
.Lfirst_warm_up_done:
	lvx	const1,0,r3
	addi	r3,r3,16

	VPMSUMD(v8,v16,const1)
	VPMSUMD(v9,v17,const1)
	VPMSUMD(v10,v18,const1)
	VPMSUMD(v11,v19,const1)
	VPMSUMD(v12,v20,const1)
	VPMSUMD(v13,v21,const1)
	VPMSUMD(v14,v22,const1)
	VPMSUMD(v15,v23,const1)

	b	.Lsecond_cool_down

	/*
	 * Short path, entered when len < 256. Chunk k of the input is
	 * multiplied by the constant at the same offset into vk; the products
	 * are then xored together and passed to the Barrett reduction.
	 */
.Lshort:
	cmpdi	r5,0
	beq	.Lzero

	LOAD_REG_ADDR(r3, .short_constants)

	/* Calculate where in the constant table we need to start */
	subfic	r6,r5,256
	add	r3,r3,r6

	/*
	 * How many 16 byte chunks?
	 * NOTE(review): len is assumed to be at least 16 here (see the top of
	 * this file); with a count of zero the first bdz below would
	 * decrement ctr past zero and not branch - confirm callers never pass
	 * 1..15.
	 */
	srdi	r7,r5,4
	mtctr	r7

	/* v19 and v20 are the two accumulators filled by the .LvN chain */
	vxor	v19,v19,v19
	vxor	v20,v20,v20

	/*
	 * NOTE(review): on this path the second vperm source is the constant
	 * register rather than the data register used on the main path;
	 * presumably byteswap only selects bytes from the first source -
	 * confirm against .byteswap_constant.
	 */
	lvx	v0,0,r4
	lvx	v16,0,r3
	VPERM(v0,v0,v16,byteswap)
	vxor	v0,v0,v8	/* xor in initial value */
	VPMSUMW(v0,v0,v16)
	bdz	.Lv0

	lvx	v1,off16,r4
	lvx	v17,off16,r3
	VPERM(v1,v1,v17,byteswap)
	VPMSUMW(v1,v1,v17)
	bdz	.Lv1

	lvx	v2,off32,r4
	lvx	v16,off32,r3
	VPERM(v2,v2,v16,byteswap)
	VPMSUMW(v2,v2,v16)
	bdz	.Lv2

	lvx	v3,off48,r4
	lvx	v17,off48,r3
	VPERM(v3,v3,v17,byteswap)
	VPMSUMW(v3,v3,v17)
	bdz	.Lv3

	lvx	v4,off64,r4
	lvx	v16,off64,r3
	VPERM(v4,v4,v16,byteswap)
	VPMSUMW(v4,v4,v16)
	bdz	.Lv4

	lvx	v5,off80,r4
	lvx	v17,off80,r3
	VPERM(v5,v5,v17,byteswap)
	VPMSUMW(v5,v5,v17)
	bdz	.Lv5

	lvx	v6,off96,r4
	lvx	v16,off96,r3
	VPERM(v6,v6,v16,byteswap)
	VPMSUMW(v6,v6,v16)
	bdz	.Lv6

	lvx	v7,off112,r4
	lvx	v17,off112,r3
	VPERM(v7,v7,v17,byteswap)
	VPMSUMW(v7,v7,v17)
	bdz	.Lv7

	/* Advance both pointers so off16..off112 reach chunks 8-15 */
	addi	r3,r3,128
	addi	r4,r4,128

	/* v8 (the initial value) has been consumed above and is reused */
	lvx	v8,0,r4
	lvx	v16,0,r3
	VPERM(v8,v8,v16,byteswap)
	VPMSUMW(v8,v8,v16)
	bdz	.Lv8

	lvx	v9,off16,r4
	lvx	v17,off16,r3
	VPERM(v9,v9,v17,byteswap)
	VPMSUMW(v9,v9,v17)
	bdz	.Lv9

	lvx	v10,off32,r4
	lvx	v16,off32,r3
	VPERM(v10,v10,v16,byteswap)
	VPMSUMW(v10,v10,v16)
	bdz	.Lv10

	lvx	v11,off48,r4
	lvx	v17,off48,r3
	VPERM(v11,v11,v17,byteswap)
	VPMSUMW(v11,v11,v17)
	bdz	.Lv11

	lvx	v12,off64,r4
	lvx	v16,off64,r3
	VPERM(v12,v12,v16,byteswap)
	VPMSUMW(v12,v12,v16)
	bdz	.Lv12

	lvx	v13,off80,r4
	lvx	v17,off80,r3
	VPERM(v13,v13,v17,byteswap)
	VPMSUMW(v13,v13,v17)
	bdz	.Lv13

	lvx	v14,off96,r4
	lvx	v16,off96,r3
	VPERM(v14,v14,v16,byteswap)
	VPMSUMW(v14,v14,v16)
	bdz	.Lv14

	/*
	 * NOTE(review): ctr starts at len / 16, which is at most 15 on this
	 * path (len < 256), so the bdz above appears to be always taken by
	 * this point and this 16th chunk not reached - confirm.
	 */
	lvx	v15,off112,r4
	lvx	v17,off112,r3
	VPERM(v15,v15,v17,byteswap)
	VPMSUMW(v15,v15,v17)

	/*
	 * Entering at .LvN and falling through xors vN down to v0 into the
	 * two accumulators, odd numbered registers into v19 and even numbered
	 * registers into v20.
	 */
.Lv15:	vxor	v19,v19,v15
.Lv14:	vxor	v20,v20,v14
.Lv13:	vxor	v19,v19,v13
.Lv12:	vxor	v20,v20,v12
.Lv11:	vxor	v19,v19,v11
.Lv10:	vxor	v20,v20,v10
.Lv9:	vxor	v19,v19,v9
.Lv8:	vxor	v20,v20,v8
.Lv7:	vxor	v19,v19,v7
.Lv6:	vxor	v20,v20,v6
.Lv5:	vxor	v19,v19,v5
.Lv4:	vxor	v20,v20,v4
.Lv3:	vxor	v19,v19,v3
.Lv2:	vxor	v20,v20,v2
.Lv1:	vxor	v19,v19,v1
.Lv0:	vxor	v20,v20,v0

	vxor	v0,v19,v20

	b	.Lbarrett_reduction

	/* len == 0: return the crc that was passed in (saved in r10) */
.Lzero:
	mr	r3,r10
	b	.Lout

FUNC_END(CRC_FUNCTION_NAME)