/**
 *
 * @file z_rradd_tests.c
 *
 * Tests and validates the core_zrradd() routine.
 *
 * @copyright 2015-2018 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
 *                      Univ. Bordeaux. All rights reserved.
 *
 * @version 6.0.1
 * @author Gregoire Pichon
 * @date 2018-07-16
 *
 * @precisions normal z -> c d s
 *
 **/
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include <time.h>
#include <pastix.h>
#include "common/common.h"
#include <lapacke.h>
#include <cblas.h>
#include "blend/solver.h"
#include "kernels/pastix_zcores.h"
#include "kernels/pastix_zlrcores.h"
#include "z_tests.h"

/**
 * Print the outcome of one test and update the failure counter.
 *
 * Prints "UNDEFINED" when the value is -1, "FAILED(<value>)" when it is
 * strictly positive, and "SUCCESS" for any other value. A failure increments
 * the variable `err`, which must be a modifiable integer visible at the
 * expansion site.
 *
 * The argument is evaluated exactly once, and the body is wrapped in
 * do { } while (0) so that the macro expands to a single statement and can be
 * used safely in an unbraced if/else.
 */
#define PRINT_RES(_ret_)                                    \
    do {                                                    \
        int print_res_value_ = (_ret_);                     \
        if ( print_res_value_ == -1 ) {                     \
            printf("UNDEFINED\n");                          \
        }                                                   \
        else if ( print_res_value_ > 0 ) {                  \
            printf("FAILED(%d)\n", print_res_value_);       \
            err++;                                          \
        }                                                   \
        else {                                              \
            printf("SUCCESS\n");                            \
        }                                                   \
    } while (0)

/**
 * Run the rradd check on one (A, B) pair for every compression method.
 *
 * The dense reference C = B - A is built with A placed at offset (offx, offy)
 * inside B, and is then handed to z_lowrank_check_rradd() once per compression
 * method.
 *
 * @param[in] mode
 *          Forwarded to z_lowrank_genmat_comp() for the generation of A and B.
 * @param[in] use_reltol
 *          Stored in the low-rank parameters and forwarded to the compression
 *          kernel.
 * @param[in] tolerance
 *          Compression tolerance; its square is forwarded to the generator.
 * @param[in] offx
 *          Row offset of A inside B.
 * @param[in] offy
 *          Column offset of A inside B.
 * @param[in,out] A
 *          Test matrix to add; it is passed to z_lowrank_genmat_comp() and its
 *          fr and lr fields are released before returning.
 * @param[in,out] B
 *          Test matrix receiving the addition; handled like A. Its m, n, rk
 *          and ld fields define the dimensions of the reference matrix C.
 *
 * @return The sum over the methods i of (check result * 2^i), or -1 if the
 *         dense reference matrix could not be allocated (reported as
 *         "UNDEFINED" by PRINT_RES).
 */
int
z_rradd_test( int mode, int use_reltol, double tolerance,
              pastix_int_t offx, pastix_int_t offy,
              test_matrix_t *A, test_matrix_t *B )
{
    test_matrix_t C;
    pastix_complex64_t *Cfr;
    pastix_complex64_t zalpha = -1.;
    double threshold = tolerance * tolerance;
    size_t nelem;
    int i, ret, rc = 0;
    /*
     * NOTE(review): compress_method is PQRCP while the two kernels are the SVD
     * ones. All three are overwritten in the method loop below, but the
     * generation step runs with this initial setting -- confirm it is intended.
     */
    pastix_lr_t lowrank = {
        .compress_when       = PastixCompressWhenEnd,
        .compress_method     = PastixCompressMethodPQRCP,
        .compress_min_width  = 0,
        .compress_min_height = 0,
        .use_reltol          = use_reltol,
        .tolerance           = tolerance,
        .core_ge2lr          = core_zge2lr_svd,
        .core_rradd          = core_zrradd_svd,
    };

    /*
     * Let's generate the test matrices:
     *   1) Generate a matrix of a given rank in dense
     *   2) Compress them with any of the compression kernels (Here PQRCP)
     *   3) Uncompress them to check only the loss generated by the rradd kernel
     */
    z_lowrank_genmat_comp( &lowrank, mode, threshold, A );
    z_lowrank_genmat_comp( &lowrank, mode, threshold, B );

    /*
     * Perform C = B - A in full rank format
     */
    C.m  = B->m;
    C.n  = B->n;
    C.rk = B->rk;
    C.ld = B->ld;

    /* Compute the element count in size_t to avoid a narrow intermediate product */
    nelem = (size_t)C.n * (size_t)C.ld;
    Cfr   = malloc( nelem * sizeof(pastix_complex64_t) );
    C.fr  = Cfr;

    /* malloc(0) may legitimately return NULL: only a non-empty request can fail */
    if ( (Cfr == NULL) && (nelem > 0) ) {
        fprintf( stderr, "z_rradd_test: failed to allocate the dense reference matrix\n" );
        rc = -1;
        goto cleanup;
    }

    /* Copy B into C */
    LAPACKE_zlacpy_work( LAPACK_COL_MAJOR, 'A', C.m, C.n,
                         B->fr, B->ld, Cfr, C.ld );

    /* C = B + \alpha * A, with A located at (offx, offy) in C */
    core_zgeadd( PastixNoTrans, A->m, A->n,
                 zalpha, A->fr,                    A->ld,
                    1.0, Cfr + offx + C.ld * offy, C.ld );

    C.norm = LAPACKE_zlange_work( LAPACK_COL_MAJOR, 'f', C.m, C.n,
                                  Cfr, C.ld, NULL );

    /* Compress C with the same criteria to get an upper bound of the rank */
    lowrank.core_ge2lr( use_reltol, tolerance, pastix_imin( C.m, C.n ),
                        C.m, C.n, C.fr, C.ld, &(C.lr) );

    fprintf( stdout, "%7s %4s %12s %12s %12s %12s (RankC=%d)\n",
             "Method", "Rank", "Time", "||B-A||_f", "||c(c(B)-c(A))-(B-A)||_f",
             "||c(B-A)-(B-A)||_f/(||B-A||_f * eps)", C.lr.rk );
    core_zlrfree( &(C.lr) );

    /* Let's test all methods we have */
    for(i=0; i<PastixCompressMethodNbr; i++)
    {
        lowrank.compress_method = i;
        lowrank.core_ge2lr = ge2lrMethods[i][PastixComplex64-2];
        lowrank.core_rradd = rraddMethods[i][PastixComplex64-2];
        ret = z_lowrank_check_rradd( &lowrank,
                                     offx, offy, zalpha,
                                     A, B, &C );
        /* Each method contributes its result weighted by its own power of two */
        rc += ret * (1 << i);
    }

  cleanup:
    core_zlrfree( &(A->lr) );
    core_zlrfree( &(B->lr) );
    free(A->fr);
    free(B->fr);
    free(C.fr);
    return rc;
}

/**
 * Run the exhaustive rradd test campaign.
 *
 * Every combination of the A sizes/ranks, B sizes/ranks and offsets listed in
 * the tables below is tested with both the absolute and the relative
 * tolerance, provided that A fits inside B at the given offset and that the
 * sum of the ranks does not exceed min(B.m, B.n).
 *
 * @return EXIT_SUCCESS if no test failed, EXIT_FAILURE otherwise.
 */
int z_rradd_long(void)
{
    int err = 0; /* Incremented by PRINT_RES on each failed test */
    int ret;
    /* all_na, all_nb and all_offy are indexed together with all_ma, all_mb and all_offx */
    int all_ma[] = { 0, 17, 32,  87 };
    int all_na[] = { 0,  4, 17,  87 };
    int all_ra[] = { 0,  4, 16 };
    int all_mb[] = { 0,  4, 48, 234, 432 };
    int all_nb[] = { 0,  2, 48, 170, 432 };
    int all_rb[] = { 0,  3, 15,  35 };
    int all_offx[] = { 0, 12, 57 };
    int all_offy[] = { 0, 7,  37 };

    int nb_ma   = sizeof( all_ma   ) / sizeof( all_ma[0]   );
    int nb_ra   = sizeof( all_ra   ) / sizeof( all_ra[0]   );
    int nb_mb   = sizeof( all_mb   ) / sizeof( all_mb[0]   );
    int nb_rb   = sizeof( all_rb   ) / sizeof( all_rb[0]   );
    int nb_offx = sizeof( all_offx ) / sizeof( all_offx[0] );

    int    ima, ira, imb, irb, ioffx;
    int    use_reltol;
    double eps = LAPACKE_dlamch_work('e');
    double tolerance = sqrt(eps);

    test_matrix_t A, B;

    /* Paired tables share one index: guard against out-of-bounds reads */
    assert( sizeof( all_na   ) == sizeof( all_ma   ) );
    assert( sizeof( all_nb   ) == sizeof( all_mb   ) );
    assert( sizeof( all_offy ) == sizeof( all_offx ) );

    for (ima=0; ima<nb_ma; ima++) {
        int skipra = 0;
        A.m  = all_ma[ima];
        A.n  = all_na[ima];
        A.ld = all_ma[ima];

        /* The first rank tested depends on the size index: max(0, ima-2) */
        for (ira=pastix_imax(0,ima-2); ira<nb_ra; ira++) {
            /* The rank has already been clamped once: larger ranks add nothing */
            if ( skipra ) {
                break;
            }
            A.rk = all_ra[ira];
            if ( A.rk > pastix_imin( A.m, A.n ) ) {
                A.rk = pastix_imin( A.m, A.n );
                skipra = 1;
            }

            for (imb=0; imb<nb_mb; imb++) {
                int skiprb = 0;
                B.m  = all_mb[imb];
                B.n  = all_nb[imb];
                B.ld = all_mb[imb];

                for (irb=pastix_imax(0,imb-2); irb<nb_rb; irb++) {
                    /* Same clamping rule as for A */
                    if ( skiprb ) {
                        break;
                    }
                    B.rk = all_rb[irb];
                    if ( B.rk > pastix_imin( B.m, B.n ) ) {
                        B.rk = pastix_imin( B.m, B.n );
                        skiprb = 1;
                    }

                    /* Skip the cases where the sum of the ranks exceeds min(B.m, B.n) */
                    if ( (A.rk+B.rk) > pastix_imin( B.m, B.n ) ) {
                        continue;
                    }

                    for (ioffx=0; ioffx<nb_offx; ioffx++) {
                        int offx = all_offx[ioffx];
                        int offy = all_offy[ioffx];

                        /* A must fit entirely inside B at the given offset */
                        if ( (A.m + offx > B.m) ||
                             (A.n + offy > B.n) )
                        {
                            continue;
                        }

                        for (use_reltol=0; use_reltol < 2; use_reltol++ ) {
                            printf( "  -- Test RRADD MA=LDA=%d, NA=%d, RA=%d, MB=LDB=%d, NB=%d, RB=%d, rkmax=%ld, %s\n",
                                    A.m, A.n, A.rk, B.m, B.n, B.rk,
                                    (long)core_get_rklimit( B.m, B.n ), use_reltol ? "relative" : "absolute" );

                            ret = z_rradd_test( 0, use_reltol, tolerance, offx, offy, &A, &B );
                            PRINT_RES(ret);
                        }
                    }
                }
            }
        }
    }

    if( err == 0 ) {
        printf(" -- All tests PASSED --\n");
        return EXIT_SUCCESS;
    }
    else
    {
        printf(" -- %d tests FAILED --\n", err);
        return EXIT_FAILURE;
    }
}

/**
 * Run the reduced rradd test campaign.
 *
 * Each column of the tables below describes one test case (sizes and ranks of
 * A and B, and the offset of A inside B); every case is run with both the
 * absolute and the relative tolerance.
 *
 * @return EXIT_SUCCESS if no test failed, EXIT_FAILURE otherwise.
 */
int z_rradd_short(void)
{
    int err = 0; /* Incremented by PRINT_RES on each failed test */
    int ret;

    /**
     * Let's test multiple cases:
     *   1) a) Addition of 2 null matrices
     *      b) Addition of a null matrix to a non-null one
     *   2) a) Addition of a null-rank matrix to a non null-rank one
     *      b) Addition of a non null-rank matrix to a null-rank one
     *   3) Addition of two non null-rank matrices
     *      a) A smaller on both dimensions than B
     *         0) A.rk + B.rk < rklimit
     *         1) A.rk + B.rk > rklimit
     *      b/c) A smaller on one dimension than B
     *         0) A.rk + B.rk < rklimit
     *         1) A.rk + B.rk > rklimit
     *      d) A of same size as B
     *         0) A.rk + B.rk < rklimit
     *         1) A.rk + B.rk > rklimit
     *             If they share singular values, we should observe recompression
     *             otherwise, we should fall back to full rank matrix
     */
    int all_ma[]   = { 0, 0, 10, 10,  32,  32, 125, 125,  37,  37, 125, 125 };
    int all_na[]   = { 0, 0, 14, 14,  25,  25,  25,  25, 102, 102, 102, 102 };
    int all_ra[]   = { 0, 0,  0,  4,   7,  10,   7,  10,   7,  10,   6,  25 };
    int all_mb[]   = { 0, 4, 25, 25, 125, 125, 125, 125, 125, 125, 125, 125 };
    int all_nb[]   = { 0, 3, 17, 17, 102, 102, 100, 100, 102, 102, 102, 102 };
    int all_rb[]   = { 0, 2,  5,  0,  15,  25,  15,  25,  15,  25,  15,  25 };
    int all_offx[] = { 0, 1,  7,  7,  67,  38,   0,   0,   0,  88,   0,   0 };
    int all_offy[] = { 0, 0,  3,  1,  47,  77,  18,  75,   0,   0,   0,   0 };

    int nb_tests = sizeof( all_ma ) / sizeof( all_ma[0] );
    int    i;
    int    use_reltol;
    double eps = LAPACKE_dlamch_work('e');
    double tolerance = sqrt(eps);

    test_matrix_t A, B;

    /* All tables are indexed by the same counter: they must have the same length */
    assert( sizeof( all_na   ) == sizeof( all_ma ) );
    assert( sizeof( all_ra   ) == sizeof( all_ma ) );
    assert( sizeof( all_mb   ) == sizeof( all_ma ) );
    assert( sizeof( all_nb   ) == sizeof( all_ma ) );
    assert( sizeof( all_rb   ) == sizeof( all_ma ) );
    assert( sizeof( all_offx ) == sizeof( all_ma ) );
    assert( sizeof( all_offy ) == sizeof( all_ma ) );

    for (i=0; i<nb_tests; i++) {
        int offx = all_offx[i];
        int offy = all_offy[i];

        /* The leading dimension of each matrix is its number of rows */
        A.m  = all_ma[i];
        A.n  = all_na[i];
        A.rk = all_ra[i];
        A.ld = all_ma[i];
        B.m  = all_mb[i];
        B.n  = all_nb[i];
        B.rk = all_rb[i];
        B.ld = all_mb[i];

        for (use_reltol=0; use_reltol < 2; use_reltol++ ) {
            printf( "  -- Test RRADD MA=LDA=%d, NA=%d, RA=%d, MB=LDB=%d, NB=%d, RB=%d, rkmax=%ld, %s\n",
                    A.m, A.n, A.rk, B.m, B.n, B.rk,
                    (long)core_get_rklimit( B.m, B.n ), use_reltol ? "relative" : "absolute" );

            ret = z_rradd_test( 0, use_reltol, tolerance, offx, offy, &A, &B );
            PRINT_RES(ret);
        }
    }

    if( err == 0 ) {
        printf(" -- All tests PASSED --\n");
        return EXIT_SUCCESS;
    }
    else
    {
        printf(" -- %d tests FAILED --\n", err);
        return EXIT_FAILURE;
    }
}

/**
 * Program entry point.
 *
 * Runs the short test set when at least one command-line argument is given,
 * and the long test set otherwise. The content of the arguments is ignored.
 *
 * @return The exit status of the selected test set.
 */
int main( int argc, char *argv[] ) {
    /* Only the number of arguments matters */
    (void)argv;

    return ( argc > 1 ) ? z_rradd_short() : z_rradd_long();
}
