import { GalaxyConfigurator } from './GalaxyConfigurator'
import { OrbitControls } from 'src/core/rendering/OrbitControls'
import { Color } from 'src/core/Color'
import { Vector3 } from 'src/core/math/Vector3'
import { UniformStruct } from 'src/core/webGPU/Uniforms/UniformStruct'
import { Attributes } from 'src/core/webGPU/Attributes'
import { randomPointOnSphere } from 'src/core/math/randomUtils'
import { Plane } from 'src/core/math/Plane'
import { BindGroup } from 'src/core/webGPU/BindGroup'
import { WebGPUProgram } from 'src/core/webGPU/WebGPUProgram'
import { AttributesBind } from 'src/core/webGPU/AttributesBind'

import starsComputeWGSL from './shaders/stars.compute.wgsl'
import starsRenderWGSL from './shaders/stars.render.wgsl'
import { ComputePipeline } from 'src/core/webGPU/ComputePipeline'
import { BindGroupArray } from 'src/core/webGPU/BindGroupArray'

/**
 * Camera uniforms consumed by the star render pass (bound alongside
 * ConfigStruct in the render pipeline's bind group).
 * NOTE(review): the order of the add*Field() calls presumably defines the
 * uniform buffer layout and must match the struct in stars.render.wgsl —
 * TODO confirm before reordering fields.
 */
class CameraStruct extends UniformStruct {
  // view-projection matrix, filled from OrbitControls.viewProjection each frame
  matrix = this.addMatrix4Field()
  // camera up / right axes (camera.getUp() / camera.getRight()); presumably
  // used by the vertex shader to orient each star quad toward the camera
  up = this.addVector3Field()
  right = this.addVector3Field()
}

/**
 * Per-frame appearance settings for the star render pass, copied from the
 * configurator options.
 * NOTE(review): field order presumably defines the uniform layout expected by
 * stars.render.wgsl — TODO confirm before reordering.
 */
class ConfigStruct extends UniformStruct {
  // bounds used to scale each star's per-instance size attribute
  starSizeMin = this.addFloatField()
  starSizeMax = this.addFloatField()
  // two colors parsed from hex strings in the options
  color1 = this.addFloatRGBColorField()
  color2 = this.addFloatRGBColorField()
}

/**
 * Uniforms for the star simulation compute pass (shared by both ping-pong
 * compute bind groups).
 * NOTE(review): field order presumably defines the uniform layout expected by
 * stars.compute.wgsl — TODO confirm before reordering.
 */
class ComputeConfig extends UniformStruct {
  // total running time of the loop, in seconds
  time = this.addFloatField()
  // time elapsed since the previous frame, in seconds
  frameTime = this.addFloatField()
  // gravity strength from the configurator options
  gravity = this.addFloatField()
  // world-space center point; moved by Ctrl + pointer drag
  center = this.addVector3Field()
}

/**
 * Galaxy particle simulation driven by a WebGPU compute shader.
 *
 * Star state (position + velocity, 2 x float32x4 per star) lives in two
 * ping-pong attribute buffers: each frame the compute pass reads one buffer
 * and writes the other, then the render pass draws one instanced 4-vertex
 * triangle-strip quad per star from the buffer that was just written.
 * Holding Ctrl while moving the pointer moves `_center`, which is uploaded
 * as the compute shader's `center` uniform.
 */
export class GalaxyGenerator4 extends WebGPUProgram {
  /**
   * Invocations per compute workgroup; used to size the dispatch.
   * NOTE(review): must match `@workgroup_size` in stars.compute.wgsl — TODO confirm.
   */
  private static readonly _computeWorkgroupSize = 64

  private _configurator?: GalaxyConfigurator
  private _orbitControls?: OrbitControls
  private _renderPipeline: GPURenderPipeline
  private _computePipeline: ComputePipeline

  private _center = new Vector3(0, 0, 0)
  private _cameraStruct = new CameraStruct()
  private _renderConfigStruct = new ConfigStruct()
  private _computeConfig = new ComputeConfig()
  // Bind group 0 of the render pipeline: camera + appearance uniforms.
  private _renderUniformGroup = new BindGroup([
    this._cameraStruct,
    this._renderConfigStruct,
  ])

  // Two identical position/velocity buffers used as ping-pong storage.
  private _dynamicAttributes = [
    Attributes.fromTypes([
      'float32x4', // position
      'float32x4', // velocity
    ]),
    Attributes.fromTypes([
      'float32x4', // position
      'float32x4', // velocity
    ]),
  ]

  // Per-star data that the compute pass never touches.
  private _staticAttributes = Attributes.fromTypes([
    'float32', // size
  ])

  // Group i reads _dynamicAttributes[i] and writes _dynamicAttributes[1 - i].
  private _computeBindGroups = new BindGroupArray([
    new BindGroup(
      [
        this._computeConfig,
        new AttributesBind(this._dynamicAttributes[0], { readOnly: true }),
        new AttributesBind(this._dynamicAttributes[1], { readOnly: false }),
      ],
      GPUShaderStage.COMPUTE
    ),
    new BindGroup(
      [
        this._computeConfig,
        new AttributesBind(this._dynamicAttributes[1], { readOnly: true }),
        new AttributesBind(this._dynamicAttributes[0], { readOnly: false }),
      ],
      GPUShaderStage.COMPUTE
    ),
  ])

  /**
   * Sets up the configurator, camera controls, pipelines and star buffers,
   * and installs the Ctrl + pointer-move handler that moves `_center`.
   * @param canvas - canvas passed through to WebGPUProgram.initialize
   */
  async initialize(canvas: HTMLCanvasElement) {
    await super.initialize(canvas)
    this._configurator = new GalaxyConfigurator()

    // Keep a non-optional local so the pointer handler below does not have to
    // dereference the optional field.
    const orbitControls = new OrbitControls(this._viewport)
    this._orbitControls = orbitControls
    // look from 45 degrees above
    orbitControls.camera.setValues(
      new Vector3(0, 1, -1).normalize(),
      new Vector3(0, 0, 0),
      new Vector3(0, 1, 1).normalize()
    )

    // Original order kept: the compute bind groups presumably need the star
    // buffers to exist before the compute pipeline initializes them — TODO confirm.
    this.createRenderPipeline()
    this.createStarsBuffers()
    this.createComputePipeline()

    this._configurator.onChange = this.optionsChangeHandler
    this._viewport.onPointerMove.add((e) => {
      if (!e.ctrl) {
        return
      }
      // Project the pointer ray onto the plane through the current center that
      // faces the camera, and move the center to the hit point.
      const camera = orbitControls.camera
      const direction = camera.getDirection()
      const ray = camera.getRay(e.position)
      const plane = Plane.fromDirectionPosition(direction, this._center)
      const distance = ray.distanceToPlane(plane)
      // Non-finite distance means the ray does not hit the plane.
      if (isFinite(distance)) {
        this._center = ray.positionAt(distance)
      }
    })
  }

  /** Releases the configurator, controls, uniform structs and attribute buffers. */
  protected disposeProgram(): void {
    this._configurator?.dispose()
    this._orbitControls?.dispose()
    this._cameraStruct.dispose()
    this._renderConfigStruct.dispose()
    this._computeConfig.dispose()
    this._dynamicAttributes.forEach((a) => a.dispose())
    this._staticAttributes.dispose()
  }

  /**
   * Builds the instanced star render pipeline.
   * Vertex buffer slot 0 holds the static size attribute and slot 1 holds the
   * dynamic position/velocity attributes (shader locations starting at 1).
   */
  private createRenderPipeline() {
    this._renderUniformGroup.initialize(this._device)
    // Both the vertex and fragment entry points live in the same WGSL source,
    // so a single shader module is compiled and shared by both stages.
    const shaderModule = this._device.createShaderModule({
      code: starsRenderWGSL,
    })
    const pipelineDescriptor: GPURenderPipelineDescriptor = {
      label: 'render pipeline',
      layout: this._device.createPipelineLayout({
        bindGroupLayouts: [this._renderUniformGroup.getLayout(this._device)],
      }),
      vertex: {
        module: shaderModule,
        entryPoint: 'vertex',
        buffers: [
          // TODO: a class to make this incrementing shader locations
          {
            stepMode: 'instance',
            attributes: this._staticAttributes.getDescriptor(),
            arrayStride: this._staticAttributes.elementSize,
          } as GPUVertexBufferLayout,
          {
            stepMode: 'instance',
            attributes: this._dynamicAttributes[0].getDescriptor(1),
            arrayStride: this._dynamicAttributes[0].elementSize,
          } as GPUVertexBufferLayout,
        ],
      },
      fragment: {
        module: shaderModule,
        entryPoint: 'fragment',
        targets: [
          {
            format: this._gpu.presentationFormat,
            // standard (non-premultiplied) alpha blending
            blend: {
              color: {
                operation: 'add',
                srcFactor: 'src-alpha',
                dstFactor: 'one-minus-src-alpha',
              } as GPUBlendComponent,
              alpha: {
                srcFactor: 'one',
                dstFactor: 'one-minus-src-alpha',
                operation: 'add',
              } as GPUBlendComponent,
            } as GPUBlendState,
          } as GPUColorTargetState,
        ],
      },
      primitive: {
        // 4 vertices per instance form one quad
        topology: 'triangle-strip',
      },
    }
    this._renderPipeline = this._device.createRenderPipeline(pipelineDescriptor)
  }

  /** Builds the star simulation compute pipeline over the ping-pong bind groups. */
  private createComputePipeline() {
    this._computePipeline = new ComputePipeline(
      {
        label: 'compute pipeline',
        compute: {
          module: {
            code: starsComputeWGSL,
          },
          entryPoint: 'main',
        },
      },
      [this._computeBindGroups]
    )
    this._computePipeline.initialize(this._device)
  }

  /**
   * (Re)creates the static and both dynamic star buffers sized for
   * `options.stars`, filling them via the attribute mappers.
   */
  private createStarsBuffers() {
    this._staticAttributes.createBuffer(
      this._configurator.options.stars,
      this._device,
      this._staticAttributesMapper,
      GPUBufferUsage.VERTEX
    )
    //TODO create on demand
    this._dynamicAttributes.forEach((a) =>
      a.createBuffer(
        this._configurator.options.stars,
        this._device,
        this._dynamicAttributesMapper,
        GPUBufferUsage.VERTEX | GPUBufferUsage.STORAGE
      )
    )
  }

  // NOTE(review): any option change rebuilds (and re-randomizes) all star
  // buffers. Whether the already-initialized compute bind groups pick up the
  // new buffers depends on AttributesBind — TODO confirm.
  private optionsChangeHandler = () => {
    this.createStarsBuffers()
  }

  /**
   * Per-frame work: upload uniforms, advance the simulation with one compute
   * pass, then draw all stars from the buffer the compute pass just wrote.
   */
  protected update() {
    const { speed, stars, autoClear } = this._configurator.options
    this._loop.speed = speed
    this.updateUniforms()

    // Bind group `readIndex` reads _dynamicAttributes[readIndex] and writes
    // _dynamicAttributes[writeIndex]; the render pass draws the written one.
    const readIndex = this._loop.frame % 2
    const writeIndex = (this._loop.frame + 1) % 2

    const commandEncoder = this._device.createCommandEncoder()
    const textureView = this._context.getCurrentTexture().createView()
    const renderPassDescriptor: GPURenderPassDescriptor = {
      colorAttachments: [
        {
          view: textureView,
          clearValue: new Color(0, 0, 0, 255).toGPUColor(),
          // keeping the previous frame (load) leaves star trails
          loadOp: autoClear ? 'clear' : 'load',
          storeOp: 'store',
        } as GPURenderPassColorAttachment,
      ],
    }

    {
      const passEncoder = commandEncoder.beginComputePass()
      // Select this frame's bind group BEFORE binding. Presumably bindTo binds
      // the currently active group, so selecting afterwards would run the pass
      // with the previous frame's selection and the render pass below would
      // draw the un-updated buffer.
      this._computeBindGroups.activeIndex = readIndex
      this._computePipeline.bindTo(passEncoder)
      passEncoder.dispatchWorkgroups(
        Math.ceil(stars / GalaxyGenerator4._computeWorkgroupSize)
      )
      passEncoder.end()
    }
    {
      const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor)
      passEncoder.setPipeline(this._renderPipeline)
      this._renderUniformGroup.bindTo(passEncoder)
      this._dynamicAttributes[writeIndex].setVertexBuffer(passEncoder, 1)
      this._staticAttributes.setVertexBuffer(passEncoder, 0)
      // 4 vertices (one triangle-strip quad) per star instance
      passEncoder.draw(4, stars, 0, 0)
      passEncoder.end()
    }

    this._device.queue.submit([commandEncoder.finish()])
  }

  /** Copies camera, appearance, and simulation values into the uniform structs. */
  private updateUniforms() {
    const { viewProjection, camera } = this._orbitControls
    const { options } = this._configurator

    this._cameraStruct.matrix.setValue(viewProjection)
    this._cameraStruct.up.setValue(camera.getUp())
    this._cameraStruct.right.setValue(camera.getRight())

    this._renderConfigStruct.starSizeMin.setValue(options.starSizeMin)
    this._renderConfigStruct.starSizeMax.setValue(options.starSizeMax)
    this._renderConfigStruct.color1.setValue(
      Color.fromHexString(options.color1)
    )
    this._renderConfigStruct.color2.setValue(
      Color.fromHexString(options.color2)
    )

    this._computeConfig.time.setValue(this._loop.totalSeconds)
    this._computeConfig.frameTime.setValue(this._loop.elapsedSeconds)
    this._computeConfig.gravity.setValue(options.gravity)
    this._computeConfig.center.setValue(this._center)
  }

  /**
   * Fills a dynamic buffer with initial star state: a random point on a sphere
   * of `options.radius` flattened along Y by `options.height`. Velocity is a
   * small random vector plus a tangential push around the Y axis.
   * Writes 8 floats per star: position xyzw, velocity xyzw (w = 1).
   */
  private _dynamicAttributesMapper = (buffer: ArrayBuffer): void => {
    const array = new Float32Array(buffer)
    const { options } = this._configurator
    let index = 0
    const galaxyAxis = new Vector3(0, 1, 0)

    for (let star = 0; star < options.stars; star++) {
      const position = randomPointOnSphere(options.radius)
      position.y *= options.height
      const velocity = randomPointOnSphere(0.03)
      // NOTE(review): relies on normalize() returning a new vector rather than
      // mutating `position` (which is written below) — TODO confirm.
      const side = galaxyAxis.cross(position.normalize())
      velocity.selfAdd(side.multiplyScalar(0.06))

      // position x, y, z, w
      array[index++] = position.x
      array[index++] = position.y
      array[index++] = position.z
      array[index++] = 1
      // velocity x, y, z, w
      array[index++] = velocity.x
      array[index++] = velocity.y
      array[index++] = velocity.z
      array[index++] = 1
    }
  }

  /** Fills the static buffer with one random size in [0, 1) per star. */
  private _staticAttributesMapper = (buffer: ArrayBuffer): void => {
    const array = new Float32Array(buffer)
    const { options } = this._configurator
    let index = 0

    for (let star = 0; star < options.stars; star++) {
      // size 0-1
      array[index++] = Math.random()
    }
  }
}
