//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Simulation/ScatteringSimulation.cpp
//! @brief     Implements interface ISimulation.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Simulation/ScatteringSimulation.h"
#include "Base/Axis/Scale.h"
#include "Base/Pixel/IPixel.h"
#include "Base/Progress/ProgressHandler.h"
#include "Base/Util/Assert.h"
#include "Device/Beam/Beam.h"
#include "Device/Beam/IFootprint.h"
#include "Device/Coord/CoordSystem2D.h"
#include "Device/Data/Datafield.h"
#include "Device/Detector/IDetector.h"
#include "Device/Detector/SimulationAreaIterator.h" // beginNonMaskedPoints
#include "Device/Histo/SimulationResult.h"
#include "Param/Distrib/DistributionHandler.h"
#include "Resample/Element/DiffuseElement.h"
#include "Sim/Background/IBackground.h"
#include "Sim/Computation/DWBAComputation.h"

ScatteringSimulation::ScatteringSimulation(const Beam& beam, const MultiLayer& sample,
                                           const IDetector& detector)
    : ISimulation(sample)
    , m_beam(beam.clone())
    , m_detector(detector.clone())
{
    m_detector->setDetectorNormal(m_beam->ki());
}

ScatteringSimulation::~ScatteringSimulation() = default;

//! Returns the detector's scattering coordinate system for the current beam.
//! NOTE(review): raw pointer return; packResult() passes it on to SimulationResult,
//! which presumably takes ownership — confirm against SimulationResult.
const ICoordSystem* ScatteringSimulation::simCoordSystem() const
{
    const Beam& incident = beam();
    const ICoordSystem* coords = m_detector->scatteringCoords(incident);
    return coords;
}

//... Overridden executors:

//! init callbacks for setting the parameter values
void ScatteringSimulation::initDistributionHandler()
{
    for (const auto& distribution : distributionHandler().paramDistributions()) {

        switch (distribution.whichParameter()) {
        case ParameterDistribution::BeamAzimuthalAngle:
            distributionHandler().defineCallbackForDistribution(
                &distribution, [&](double d) { m_beam->setAzimuthalAngle(d); });
            break;
        case ParameterDistribution::BeamInclinationAngle:
            distributionHandler().defineCallbackForDistribution(
                &distribution, [&](double d) { m_beam->setInclination(d); });
            break;
        case ParameterDistribution::BeamWavelength:
            distributionHandler().defineCallbackForDistribution(
                &distribution, [&](double d) { m_beam->setWavelength(d); });
            break;
        default:
            ASSERT(false);
        }
    }
}

void ScatteringSimulation::prepareSimulation()
{
    m_active_indices = m_detector->active_indices();
    m_pixels.reserve(m_active_indices.size());
    for (auto detector_index : m_active_indices)
        m_pixels.emplace_back(m_detector->createPixel(detector_index));
}

void ScatteringSimulation::runComputation(const ReSample& re_sample, size_t i, double weight)
{
    if (m_cache.empty())
        m_cache.resize(nElements(), 0.0);

    const bool isSpecular = m_active_indices[i] == m_detector->indexOfSpecular(beam());

    DiffuseElement ele(beam().wavelength(), beam().alpha_i(), beam().phi_i(), m_pixels[i],
                       beam().polMatrix(), m_detector->analyzer().matrix(), isSpecular);

    double intensity = Compute::scattered_and_reflected(re_sample, options(), ele);

    if (const auto* footprint = beam().footprint())
        intensity *= footprint->calculate(beam().alpha_i());

    double sin_alpha_i = std::abs(std::sin(beam().alpha_i()));
    if (sin_alpha_i == 0.0) {
        intensity = 0;
    } else {
        const double solid_angle = m_pixels[i]->solidAngle();
        intensity *= m_beam->intensity() * solid_angle / sin_alpha_i;
    }

    if (background())
        intensity = background()->addBackground(intensity);

    m_cache[i] += intensity * weight;

    progress().incrementDone(1);
}

//... Overridden getters:

bool ScatteringSimulation::force_polarized() const
{
    return m_detector->analyzer().BlochVector() != R3{};
}

size_t ScatteringSimulation::nElements() const
{
    return m_active_indices.size();
}

//! Scatters the accumulated per-element intensities back into a detector map and applies
//! the detector resolution. Returns the map together with the scattering coordinates.
//! The k-th cache entry goes to the k-th non-masked point in the detector's iteration order.
SimulationResult ScatteringSimulation::packResult()
{
    Datafield result_map(m_detector->createDetectorMap());

    size_t cache_pos = 0;
    auto fill_from_cache = [&](const auto it) {
        result_map[it.roiIndex()] = m_cache[cache_pos];
        ++cache_pos;
    };
    m_detector->iterateOverNonMaskedPoints(fill_from_cache);

    m_detector->applyDetectorResolution(&result_map);

    return {result_map, simCoordSystem()};
}
