//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Resample/Specular/TransitionMagneticNevot.cpp
//! @brief     Implements function Compute::refractionMatrixBlocksNevot.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2020
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Resample/Specular/TransitionMagneticNevot.h"
#include "Base/Spin/SpinMatrix.h"
#include "Base/Util/Assert.h"
#include "Resample/Flux/MatrixFlux.h"
#include <complex>
#include <limits>

//! Returns refraction matrix blocks s_{ab}^+-.
//! See PhysRef, chapter "Polarized", section "Nevot-Croce approximation".
//!
//! @param TR_a   flux object supplying k_eigen_up/dn, field and the inverse kappa matrix
//! @param TR_b   flux object supplying k_eigen_up/dn, field and the kappa matrix
//! @param sigma  width parameter entering the damping exponents; must be > 0 (asserted)
//! @return       pair {sp, sm} with
//!               sp = 1/2 (1 + kappa_a^{-1} kappa_b) * R(-1),
//!               sm = 1/2 (1 - kappa_a^{-1} kappa_b) * R(+1),
//!               where R(sign) is the 2x2 damping matrix built by the local lambda.
std::pair<SpinMatrix, SpinMatrix>
Compute::refractionMatrixBlocksNevot(const MatrixFlux& TR_a, const MatrixFlux& TR_b, double sigma)
{
    ASSERT(sigma > 0);

    // Builds the damping matrix for the combination "b + sign * a" of the two sides.
    auto roughness_matrix = [&TR_a, &TR_b, sigma](double sign) -> SpinMatrix {
        // Complex magnitudes below this threshold are treated as zero.
        constexpr double tiny = std::numeric_limits<double>::epsilon() * 10.;

        // Sum and difference of the two k eigenvalues on each side.
        const complex_t alpha_a = TR_a.k_eigen_up() + TR_a.k_eigen_dn();
        const complex_t alpha_b = TR_b.k_eigen_up() + TR_b.k_eigen_dn();
        const complex_t beta_a = TR_a.k_eigen_up() - TR_a.k_eigen_dn();
        const complex_t beta_b = TR_b.k_eigen_up() - TR_b.k_eigen_dn();

        const complex_t alpha = alpha_b + sign * alpha_a;
        C3 b = beta_b * TR_b.field() + sign * beta_a * TR_a.field();

        // Bilinear (non-conjugated) square, so that b / sqrt(square(b)) squares to 1.
        auto square = [](auto& v) { return v.x() * v.x() + v.y() * v.y() + v.z() * v.z(); };
        const complex_t beta = std::sqrt(square(b));
        if (std::abs(beta) < tiny) {
            // No direction-dependent part: the matrix is a multiple of the unit matrix.
            const complex_t alpha_pp = -(alpha * alpha) * sigma * sigma / 8.;
            return SpinMatrix(std::exp(alpha_pp), 0, 0, std::exp(alpha_pp));
        }

        b /= beta;

        const complex_t alpha_pp = -(alpha * alpha + beta * beta) * sigma * sigma / 8.;
        const complex_t beta_pp = -alpha * beta * sigma * sigma / 4.;

        const complex_t one_plus_bz = 1. + b.z();
        if (std::abs(one_plus_bz) < tiny) {
            // b is (numerically) the direction (0,0,-1): Q below and the denominator
            // 2(1+b_z) both vanish, so the decomposition would evaluate to 0/0.
            // Use the division-free closed form
            //   exp(beta_pp * b.sigma) = cosh(beta_pp) * 1 + sinh(beta_pp) * b.sigma,
            // valid because the normalized b satisfies b.b = 1; for real b it coincides
            // with Q M Q^+ / (2(1+b_z)).
            const complex_t ch = std::cosh(beta_pp);
            const complex_t sh = std::sinh(beta_pp);
            return std::exp(alpha_pp)
                   * SpinMatrix(ch + sh * b.z(), sh * (b.x() - I * b.y()),
                                sh * (b.x() + I * b.y()), ch - sh * b.z());
        }

        // Q = b.sigma + sigma_z; M is the exponential in the basis in which b.sigma is diagonal.
        const SpinMatrix Q(b.z() + 1., b.x() - I * b.y(), b.x() + I * b.y(), -1. - b.z());
        const SpinMatrix M(std::exp(beta_pp), 0, 0, std::exp(-beta_pp));

        return std::exp(alpha_pp) * Q * M * Q.adjoint() / (2. * one_plus_bz);
    };

    const auto kk = SpinMatrix(TR_a.computeInverseKappa() * TR_b.computeKappa());
    const auto sp = 0.5 * (SpinMatrix::One() + kk) * roughness_matrix(-1.);
    const auto sm = 0.5 * (SpinMatrix::One() - kk) * roughness_matrix(+1.);

    return {sp, sm};
}
