//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================

#include "Benchmarker.h"

#include <vtkm/TypeTraits.h>

#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
#include <vtkm/cont/Initialize.h>
#include <vtkm/cont/Timer.h>

#include <vtkm/source/Tangle.h>

#include <vtkm/rendering/Camera.h>
#include <vtkm/rendering/raytracing/Ray.h>
#include <vtkm/rendering/raytracing/RayTracer.h>
#include <vtkm/rendering/raytracing/SphereIntersector.h>
#include <vtkm/rendering/raytracing/TriangleExtractor.h>

#include <vtkm/exec/FunctorBase.h>

#include <sstream>
#include <string>
#include <vector>

namespace
{

// Hold configuration state (e.g. active device).
// Assigned in main() from the result of vtkm::cont::Initialize(). The
// benchmark reads Config.Device to construct its vtkm::cont::Timer, and main()
// reads Config.Usage / Config.Device after initialization.
vtkm::cont::InitializeResult Config;

// Benchmarks ray tracing of the triangles extracted from a "tangle" data set.
//
// Untimed setup:
//  - builds a 128x128x128-point tangle data set and extracts its triangles,
//  - creates camera rays for a 1920x1080 canvas,
//  - configures a RayTracer with the "tangle" field and a color map sampled
//    from the "cool to warm" color table,
//  - performs one render before measurement starts.
//
// Timed loop: every iteration regenerates the rays and renders them. The
// elapsed time of a vtkm::cont::Timer constructed with Config.Device is
// reported through state.SetIterationTime().
void BenchRayTracing(::benchmark::State& state)
{
  // Number of color-map entries. Shared by the table sampling, the float color
  // array allocation and the conversion loop so the three cannot drift apart.
  constexpr vtkm::Int32 numColorSamples = 100;

  vtkm::source::Tangle maker;
  maker.SetPointDimensions({ 128, 128, 128 });
  vtkm::cont::DataSet dataset = maker.Execute();
  vtkm::cont::CoordinateSystem coords = dataset.GetCoordinateSystem();

  // Compute the bounds once from the coordinate system fetched above and reuse
  // them for the camera reset and the initial ray creation.
  const vtkm::Bounds bounds = coords.GetBounds();
  vtkm::rendering::Camera camera;
  camera.ResetToBounds(bounds);

  vtkm::cont::UnknownCellSet cellset = dataset.GetCellSet();

  vtkm::rendering::raytracing::TriangleExtractor triExtractor;
  triExtractor.ExtractCells(cellset);

  // Construct the intersector in place; passing a temporary to make_shared
  // would only add a redundant copy/move of a default-constructed object.
  auto triIntersector = std::make_shared<vtkm::rendering::raytracing::TriangleIntersector>();
  triIntersector->SetData(coords, triExtractor.GetTriangles());

  vtkm::rendering::raytracing::RayTracer tracer;
  tracer.AddShapeIntersector(triIntersector);

  // The canvas is only used as the source of the image dimensions.
  vtkm::rendering::CanvasRayTracer canvas(1920, 1080);
  vtkm::rendering::raytracing::Camera rayCamera;
  rayCamera.SetParameters(camera,
                          static_cast<vtkm::Int32>(canvas.GetWidth()),
                          static_cast<vtkm::Int32>(canvas.GetHeight()));
  vtkm::rendering::raytracing::Ray<vtkm::Float32> rays;
  rayCamera.CreateRays(rays, bounds);

  // Fill the first ray buffer with zeros before rendering.
  rays.Buffers.at(0).InitConst(0.f);

  vtkm::cont::Field field = dataset.GetField("tangle");
  vtkm::Range range = field.GetRange().ReadPortal().Get(0);

  tracer.SetField(field, range);

  // Sample the color table as 8-bit RGBA, then convert each entry to
  // floating-point components by scaling with 1/255.
  vtkm::cont::ArrayHandle<vtkm::Vec4ui_8> temp;
  vtkm::cont::ColorTable table("cool to warm");
  table.Sample(numColorSamples, temp);

  vtkm::cont::ArrayHandle<vtkm::Vec4f_32> colors;
  colors.Allocate(numColorSamples);
  auto portal = colors.WritePortal();
  auto colorPortal = temp.ReadPortal();
  constexpr vtkm::Float32 conversionToFloatSpace = (1.0f / 255.0f);
  for (vtkm::Id i = 0; i < numColorSamples; ++i)
  {
    auto color = colorPortal.Get(i);
    vtkm::Vec4f_32 t(color[0] * conversionToFloatSpace,
                     color[1] * conversionToFloatSpace,
                     color[2] * conversionToFloatSpace,
                     color[3] * conversionToFloatSpace);
    portal.Set(i, t);
  }

  tracer.SetColorMap(colors);

  // One untimed render before measuring -- presumably a warm-up so one-time
  // setup cost is not attributed to the first timed iteration.
  tracer.Render(rays);

  vtkm::cont::Timer timer{ Config.Device };
  for (auto _ : state)
  {
    (void)_; // silence unused-variable warnings for the loop variable
    timer.Start();
    // The bounds query is intentionally left inside the timed region so the
    // measured work is the same as before this cleanup.
    rayCamera.CreateRays(rays, coords.GetBounds());
    tracer.Render(rays);
    timer.Stop();

    state.SetIterationTime(timer.GetElapsedTime());
  }
}

// Hand BenchRayTracing to the VTKM_BENCHMARK macro (presumably provided by
// Benchmarker.h) so the benchmark runner invoked from main() can execute it.
VTKM_BENCHMARK(BenchRayTracing);

} // end anonymous namespace

// Program entry point: initializes VTK-m, restricts execution to the selected
// device (unless usage information is being printed), and then runs the
// registered benchmarks with the remaining arguments.
int main(int argc, char* argv[])
{
  // Mutable copy of the argument vector that the helpers below operate on.
  std::vector<char*> argList(argv, argv + argc);

  auto initFlags = vtkm::cont::InitializeOptions::RequireDevice;
  vtkm::bench::detail::InitializeArgs(&argc, argList, initFlags);

  // Let VTK-m parse its own options and remember the outcome globally.
  Config = vtkm::cont::Initialize(argc, argList.data(), initFlags);

  // The flags end up as None when help was requested (presumably reset by
  // InitializeArgs above); in that case only the usage text is printed.
  const bool helpRequested = (initFlags == vtkm::cont::InitializeOptions::None);
  if (!helpRequested)
  {
    vtkm::cont::GetRuntimeDeviceTracker().ForceDevice(Config.Device);
  }
  else
  {
    std::cout << Config.Usage << std::endl;
  }

  // Process the benchmark-related arguments and execute the benchmarks.
  VTKM_EXECUTE_BENCHMARKS(argc, argList.data());
}
