// Copyright 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cassert>
#include <cstddef>
#include <vector>

#include "../common/ManagedObject.h"
#include "../common/simd.h"
#include "../iterator/Iterator.h"
#include "../iterator/IteratorContext.h"
#include "../observer/Observer.h"
#include "openvkl/openvkl.h"
#include "rkcommon/utility/getEnvVar.h"

#include "../../common/StructShared.h"
#include "SamplerShared.h"

namespace openvkl {
  namespace cpu_device {

    // Helpers ////////////////////////////////////////////////////////////////

    // Debug-build sanity check for a single sample time: asserts that `time`
    // lies in the closed interval [0, 1]. Compiles to an empty function when
    // NDEBUG is defined.
    inline void assertValidTime(const float time)
    {
#ifndef NDEBUG
      assert(0.f <= time && time <= 1.0f);
#endif
    }

    // Debug-build sanity check for a W-wide vector of sample times: every
    // lane whose `valid` entry is non-zero must hold a time in [0, 1]. Lanes
    // with a zero `valid` entry are not inspected. Compiles to an empty
    // function when NDEBUG is defined.
    template <int W>
    inline void assertValidTimes(const vintn<W> &valid, const vfloatn<W> &time)
    {
#ifndef NDEBUG
      for (int lane = 0; lane < W; ++lane) {
        if (valid[lane]) {
          assert(0.f <= time[lane] && time[lane] <= 1.0f);
        }
      }
#endif
    }

    // Debug-build sanity check for an array of N sample times: each entry
    // must lie in the closed interval [0, 1]. A null `times` pointer is
    // accepted and skips the check entirely. Compiles to an empty function
    // when NDEBUG is defined.
    inline void assertAllValidTimes(unsigned int N, const float *times)
    {
#ifndef NDEBUG
      // The null test does not depend on the loop index, so do it once.
      if (times == nullptr) {
        return;
      }

      // The index has the same unsigned type as N: no signed/unsigned
      // comparison and no signed overflow for very large N.
      for (unsigned int i = 0; i < N; ++i) {
        assert(times[i] >= 0.f && times[i] <= 1.0f);
      }
#endif
    }

    // Debug-build sanity check for a list of M attribute indices: each index
    // must be strictly less than volume->getNumAttributes(). `volume` is
    // accessed through operator->, so VolumeType has to be a pointer or a
    // pointer-like handle. Compiles to an empty function when NDEBUG is
    // defined.
    template <typename VolumeType>
    inline void assertValidAttributeIndices(
        const VolumeType &volume,
        unsigned int M,
        const unsigned int *attributeIndices)
    {
#ifndef NDEBUG
      // The index has the same unsigned type as M (no signed/unsigned
      // comparison). getNumAttributes() stays inside the loop so that a call
      // with M == 0 never touches `volume`.
      for (unsigned int i = 0; i < M; ++i) {
        assert(attributeIndices[i] < volume->getNumAttributes());
      }
#endif
    }

    // Sampler ////////////////////////////////////////////////////////////////

    // Forward declaration only: Sampler<W> below returns Volume<W> by
    // reference from getVolume(), which does not need the full definition.
    template <int W>
    class Volume;

    // Abstract sampler interface, templated on W, the lane count of the
    // vintn<W> / vfloatn<W> / vvec3fn<W> arguments used by the *V entry
    // points. The object is neither copyable nor movable.
    //
    // Naming of the entry points as used in this header:
    //   (no suffix) one coordinate, expressed with width-1 vector types
    //   V           W lanes at once, selected by a per-lane `valid` mask
    //   N           an array of N coordinates
    //   M / MV / MN the same three shapes for M attributes per coordinate
    template <int W>
    struct Sampler : public AddStructShared<ManagedObject, ispc::SamplerShared>
    {
      Sampler(Device *device) : AddStructShared<ManagedObject, ispc::SamplerShared>(device) {}  // not = default, due to ICC 19 compiler bug
      Sampler(Sampler &&) = delete;
      Sampler &operator=(Sampler &&) = delete;
      Sampler(const Sampler &)       = delete;
      Sampler &operator=(const Sampler &) = delete;

      // Declared here only; the definition is not part of this header.
      virtual ~Sampler();

      // Feature flags of the concrete sampler (pure virtual; the meaning of
      // the individual flags is defined with VKLFeatureFlagsInternal).
      virtual VKLFeatureFlagsInternal getFeatureFlags() const = 0;

      // single attribute /////////////////////////////////////////////////////

      // Samplers can optionally define a scalar sampling method; if they do
      // not, the default implementation further down in this header widens
      // the inputs to W lanes with only lane 0 marked valid, calls
      // computeSampleV(), and returns lane 0 of the result.
      virtual void computeSample(const vvec3fn<1> &objectCoordinates,
                                 vfloatn<1> &samples,
                                 unsigned int attributeIndex,
                                 const vfloatn<1> &times) const;

      // Samples one attribute for W lanes; `valid` is the per-lane mask and
      // `samples` receives one value per lane.
      virtual void computeSampleV(const vintn<W> &valid,
                                  const vvec3fn<W> &objectCoordinates,
                                  vfloatn<W> &samples,
                                  unsigned int attributeIndex,
                                  const vfloatn<W> &times) const = 0;

      // Samples one attribute for N coordinates; `samples` receives N values.
      virtual void computeSampleN(unsigned int N,
                                  const vvec3fn<1> *objectCoordinates,
                                  float *samples,
                                  unsigned int attributeIndex,
                                  const float *times) const = 0;

      // Gradient counterpart of computeSampleV(): one vec3 per lane.
      virtual void computeGradientV(const vintn<W> &valid,
                                    const vvec3fn<W> &objectCoordinates,
                                    vvec3fn<W> &gradients,
                                    unsigned int attributeIndex,
                                    const vfloatn<W> &times) const = 0;

      // Gradient counterpart of computeSampleN(): N vec3 values.
      virtual void computeGradientN(unsigned int N,
                                    const vvec3fn<1> *objectCoordinates,
                                    vvec3fn<1> *gradients,
                                    unsigned int attributeIndex,
                                    const float *times) const = 0;

      // multi-attribute //////////////////////////////////////////////////////

      // The three methods below have default implementations further down in
      // this header; each loops a = 0 .. M-1 over the matching
      // single-attribute method.

      // Default: one computeSample() call per a, result stored in samples[a].
      virtual void computeSampleM(const vvec3fn<1> &objectCoordinates,
                                  float *samples,
                                  unsigned int M,
                                  const unsigned int *attributeIndices,
                                  const vfloatn<1> &times) const;

      // Default: one computeSampleV() call per a; lane i of that result is
      // stored in samples[a * W + i], so `samples` holds M * W floats.
      virtual void computeSampleMV(const vintn<W> &valid,
                                   const vvec3fn<W> &objectCoordinates,
                                   float *samples,
                                   unsigned int M,
                                   const unsigned int *attributeIndices,
                                   const vfloatn<W> &times) const;

      // Default: one computeSampleN() call per a; coordinate i of that result
      // is stored in samples[i * M + a], so `samples` holds N * M floats.
      virtual void computeSampleMN(unsigned int N,
                                   const vvec3fn<1> *objectCoordinates,
                                   float *samples,
                                   unsigned int M,
                                   const unsigned int *attributeIndices,
                                   const float *times) const;

      // Creates an observer of the named type (pure virtual here).
      // NOTE(review): ownership of the returned pointer is not visible in
      // this header - confirm against callers.
      virtual Observer<W> *newObserver(const char *type) = 0;

      /*
       * Samplers keep references to their underlying volumes!
       * These accessors expose that volume.
       */
      virtual Volume<W> &getVolume()             = 0;
      virtual const Volume<W> &getVolume() const = 0;

      /*
       * Return the interval and hit iterator factories for this sampler.
       */
      virtual const IteratorFactory<W,
                                    IntervalIterator,
                                    IntervalIteratorContext>
          &getIntervalIteratorFactory() const = 0;

      virtual const IteratorFactory<W, HitIterator, HitIteratorContext>
          &getHitIteratorFactory() const = 0;

     protected:
      /*
       * Return if specialization constants (feature flags) should be disabled.
       * The implementation in this header reads the integer environment
       * variable OPENVKL_GPU_DEVICE_DEBUG_DISABLE_SPEC_CONST and treats an
       * unset variable as 0 (not disabled).
       */
      bool isSpecConstsDisabled() const;
    };

    // Inlined definitions ////////////////////////////////////////////////////

    // Default scalar sampling path: widens the single coordinate / time to a
    // W-lane vector with only lane 0 marked valid, forwards to
    // computeSampleV(), and hands back lane 0 of the result. See
    // CPUDevice<W>::computeSampleAnyWidth() for the same degradation scheme.
    template <int W>
    inline void Sampler<W>::computeSample(const vvec3fn<1> &objectCoordinates,
                                          vfloatn<1> &samples,
                                          unsigned int attributeIndex,
                                          const vfloatn<1> &times) const
    {
      // Lane mask with lane 0 switched on and every other lane switched off.
      vintn<W> laneMask;
      for (int lane = 0; lane < W; ++lane) {
        if (lane == 0) {
          laneMask[lane] = 1;
        } else {
          laneMask[lane] = 0;
        }
      }

      // Widen the coordinate, then give the inactive lanes defined contents.
      vvec3fn<W> wideCoordinates = static_cast<vvec3fn<W>>(objectCoordinates);
      wideCoordinates.fill_inactive_lanes(laneMask);

      // Same treatment for the sample time.
      vfloatn<W> wideTimes = static_cast<vfloatn<W>>(times);
      wideTimes.fill_inactive_lanes(laneMask);

      vfloatn<W> wideSamples;
      computeSampleV(
          laneMask, wideCoordinates, wideSamples, attributeIndex, wideTimes);

      // Only lane 0 carries the requested sample.
      samples[0] = wideSamples[0];
    }

    // Default scalar multi-attribute sampling: for each of the M entries of
    // `attributeIndices`, samples that attribute at `objectCoordinates` /
    // `times` through computeSample() and stores the result in samples[a].
    // `samples` must therefore have room for M floats.
    template <int W>
    inline void Sampler<W>::computeSampleM(const vvec3fn<1> &objectCoordinates,
                                           float *samples,
                                           unsigned int M,
                                           const unsigned int *attributeIndices,
                                           const vfloatn<1> &times) const
    {
      for (unsigned int a = 0; a < M; a++) {
        // The attribute to sample is attributeIndices[a], not the output
        // slot a; this matches computeSampleMV() and computeSampleMN().
        // samples[a] is viewed in place as a one-lane vfloatn so the result
        // of computeSample() lands directly in the caller's buffer.
        computeSample(objectCoordinates,
                      reinterpret_cast<vfloatn<1> &>(samples[a]),
                      attributeIndices[a],
                      times);
      }
    }

    // Default W-wide multi-attribute sampling: runs computeSampleV() once
    // per entry of `attributeIndices` and stores the W lane results of the
    // a-th attribute contiguously at samples[a * W .. a * W + W - 1].
    // `samples` must therefore have room for M * W floats.
    template <int W>
    inline void Sampler<W>::computeSampleMV(
        const vintn<W> &valid,
        const vvec3fn<W> &objectCoordinates,
        float *samples,
        unsigned int M,
        const unsigned int *attributeIndices,
        const vfloatn<W> &times) const
    {
      for (unsigned int attr = 0; attr < M; ++attr) {
        vfloatn<W> laneSamples;
        computeSampleV(valid,
                       objectCoordinates,
                       laneSamples,
                       attributeIndices[attr],
                       times);

        // First output slot belonging to this attribute.
        const unsigned int rowStart = attr * W;
        for (int lane = 0; lane < W; ++lane) {
          samples[rowStart + lane] = laneSamples[lane];
        }
      }
    }

    // Default N-coordinate multi-attribute sampling: runs computeSampleN()
    // once per entry of `attributeIndices` into a scratch buffer and then
    // scatters it so the M attribute values of coordinate i are contiguous:
    // samples[i * M + a]. `samples` must therefore have room for N * M
    // floats.
    template <int W>
    inline void Sampler<W>::computeSampleMN(
        unsigned int N,
        const vvec3fn<1> *objectCoordinates,
        float *samples,
        unsigned int M,
        const unsigned int *attributeIndices,
        const float *times) const
    {
      // Per-attribute scratch buffer, reused across attributes.
      std::vector<float> samplesN(N);

      // The output offset is computed in std::size_t: in 32-bit unsigned
      // arithmetic i * M + a wraps as soon as N * M reaches 2^32.
      const std::size_t stride = M;

      for (unsigned int a = 0; a < M; a++) {
        computeSampleN(
            N, objectCoordinates, samplesN.data(), attributeIndices[a], times);

        for (std::size_t i = 0; i < N; i++)
          samples[i * stride + a] = samplesN[i];
      }
    }

    // Reports whether specialization constants should be disabled: reads the
    // integer environment variable OPENVKL_GPU_DEVICE_DEBUG_DISABLE_SPEC_CONST
    // (taken as 0 when it is not set) and returns true for any non-zero
    // value.
    template <int W>
    inline bool Sampler<W>::isSpecConstsDisabled() const
    {
      const int disableFlag =
          rkcommon::utility::getEnvVar<int>(
              "OPENVKL_GPU_DEVICE_DEBUG_DISABLE_SPEC_CONST")
              .value_or(0);

      return disableFlag != 0;
    }

    ///////////////////////////////////////////////////////////////////////////

    // SamplerBase is the base class for all concrete sampler types.
    // It takes care of keeping a reference to the volume, and provides
    // sensible default implementations where possible.
    //
    // Template parameters:
    //   W                        lane count, forwarded to Sampler<W>
    //   VolumeT                  volume class template; VolumeT<W> is the
    //                            concrete volume type returned by getVolume()
    //   IntervalIteratorFactory  factory template held by value and returned
    //                            from getIntervalIteratorFactory()
    //   HitIteratorFactory       factory template held by value and returned
    //                            from getHitIteratorFactory()

    template <int W,
              template <int>
              class VolumeT,
              template <int>
              class IntervalIteratorFactory,
              template <int>
              class HitIteratorFactory>
    struct SamplerBase
        : public AddStructShared<Sampler<W>, ispc::SamplerBaseShared>
    {
      // Stores the address of `volume` in a rkcommon::memory::Ref (the
      // `volume` member below); the constructor parameter intentionally
      // shadows that member inside the initializer list.
      explicit SamplerBase(Device *device, VolumeT<W> &volume)
          : AddStructShared<Sampler<W>, ispc::SamplerBaseShared>(device),
            volume(&volume)
      {
      }

      // Returns the volume passed at construction, typed as the concrete
      // VolumeT<W> rather than the Volume<W> declared in Sampler<W>.
      VolumeT<W> &getVolume() override
      {
        return *volume;
      }

      const VolumeT<W> &getVolume() const override
      {
        return *volume;
      }

      // Default: no observer is available for any `type`.
      Observer<W> *newObserver(const char *type) override
      {
        /*
         * This is a place to provide potential default observers that
         * work for *all* samplers, if we ever find such a thing.
         */
        return nullptr;
      }

      // Both factory accessors are final: they always return the factory
      // members stored in this object.
      const IteratorFactory<W, IntervalIterator, IntervalIteratorContext>
          &getIntervalIteratorFactory() const override final
      {
        return intervalIteratorFactory;
      }

      const IteratorFactory<W, HitIterator, HitIteratorContext>
          &getHitIteratorFactory() const override final
      {
        return hitIteratorFactory;
      }

     protected:
      // Reference to the sampled volume, set once in the constructor.
      rkcommon::memory::Ref<VolumeT<W>> volume;
      // Default-constructed factories owned by this sampler.
      IntervalIteratorFactory<W> intervalIteratorFactory;
      HitIteratorFactory<W> hitIteratorFactory;
    };

  }  // namespace cpu_device
}  // namespace openvkl
