import { Configurator } from './Configurator'
import { Color } from 'src/core/Color'
import { UniformStruct } from 'src/core/webGPU/Uniforms/UniformStruct'
import { Attributes } from 'src/core/webGPU/Attributes'
import { BindGroup } from 'src/core/webGPU/BindGroup'
import { Colors } from './Colors'
import { WebGPUProgram } from 'src/core/webGPU/WebGPUProgram'
import { AttributesBind } from 'src/core/webGPU/AttributesBind'
import { UniformFloatRGBColorField } from 'src/core/webGPU/Uniforms/UniformValueField'
import particlesComputeWGSL from './shaders/particles.compute.wgsl'
import particlesRenderWGSL from './shaders/particles.render.wgsl'
import { Vector2 } from 'src/core/math/Vector2'
import { Accumulator } from './Accumulator'

/**
 * Uniform values read by the particle render pass.
 *
 * NOTE(review): fields are appended in declaration order via `add*Field`, so
 * this order presumably defines the uniform buffer layout and must match the
 * corresponding struct in particles.render.wgsl — confirm before reordering.
 */
class RenderConfigStruct extends UniformStruct {
  aspectRatio = this.addFloatField()
  particleSize = this.addFloatField()
  // Camera offset. Set from the CPU-side drag position, or written by a
  // GPU-side buffer copy when the program follows a particle.
  cameraPosition = this.addVector2Field()
  // One RGB color per possible color index (Colors.MAX entries).
  colors = this.addArrayField(Colors.MAX, () => new UniformFloatRGBColorField())
}

/**
 * Uniform values read by the particle compute pass. A single instance is
 * shared by both compute bind groups.
 *
 * NOTE(review): fields are appended in declaration order via `add*Field`, so
 * this order presumably defines the uniform buffer layout and must match the
 * corresponding struct in particles.compute.wgsl — confirm before reordering.
 */
class ComputeConfigStruct extends UniformStruct {
  // Elapsed frame time divided by the number of compute sub-steps per frame.
  frameTime = this.addFloatField()
  force = this.addFloatField()
  repulsionForce = this.addFloatField()
  dampening = this.addFloatField()
  // NOTE(review): rmin/rmax are presumably the min/max interaction radii —
  // confirm against the compute shader.
  rmin = this.addFloatField()
  rmax = this.addFloatField()
  aspectRatio = this.addFloatField()
}

/**
 * Particle-life simulation stepped and rendered with WebGPU.
 *
 * Per-particle position/velocity lives in two buffers (`_dynamicAttributes`)
 * used ping-pong style:
 * - compute bind group 0 reads buffer 0 and writes buffer 1;
 * - bind group 1 does the reverse;
 * - `_currentBindGroup` toggles before every compute pass.
 *
 * The render pass and the follow-camera copy always read the buffer written
 * by the most recent pass (see `getLatestDynamicAttributes`).
 */
export class ParticleLifeWebGPU extends WebGPUProgram {
  private _configurator = new Configurator()
  private _accumulator = new Accumulator(this._gpu)

  private _renderPipeline: GPURenderPipeline
  private _computePipeline: GPUComputePipeline

  private _renderConfigStruct = new RenderConfigStruct()
  private _renderUniformGroup = new BindGroup([this._renderConfigStruct])

  private _computeConfig = new ComputeConfigStruct()

  // Two identical layouts so a compute pass can read one while writing the other.
  private _dynamicAttributes = [
    Attributes.fromTypes([
      'float32x2', // position
      'float32x2', // velocity
    ]),
    Attributes.fromTypes([
      'float32x2', // position
      'float32x2', // velocity
    ]),
  ]

  // Per-particle data that the compute shader binds read-only.
  private _staticAttributes = Attributes.fromTypes([
    'uint32', // color index
  ])

  // Colors.MAX ** 2 float32 values, filled by _colorInteractionForcesMapper.
  private _colorInteractionForces = Attributes.fromTypes(['float32'])

  // Index of the compute bind group used by the most recent pass. The buffer
  // it wrote is _dynamicAttributes[(_currentBindGroup + 1) % 2].
  private _currentBindGroup = 0
  private _computeBindGroups = [
    new BindGroup(
      [
        this._computeConfig,
        new AttributesBind(this._dynamicAttributes[0], { readOnly: true }),
        new AttributesBind(this._dynamicAttributes[1], { readOnly: false }),
        new AttributesBind(this._staticAttributes, { readOnly: true }),
        new AttributesBind(this._colorInteractionForces, { readOnly: true }),
      ],
      GPUShaderStage.COMPUTE
    ),
    new BindGroup(
      [
        this._computeConfig,
        new AttributesBind(this._dynamicAttributes[1], { readOnly: true }),
        new AttributesBind(this._dynamicAttributes[0], { readOnly: false }),
        new AttributesBind(this._staticAttributes, { readOnly: true }),
        new AttributesBind(this._colorInteractionForces, { readOnly: true }),
      ],
      GPUShaderStage.COMPUTE
    ),
  ]

  // Camera offset accumulated from drag gestures. It is pushed to the render
  // uniforms only while not following a particle.
  private _cameraPosition = new Vector2()

  /**
   * Initializes the program. In order, it:
   * - initializes the base program;
   * - wires the configurator callbacks;
   * - builds both pipelines and all particle buffers;
   * - subscribes to viewport drag (camera pan) and resize (accumulator
   *   resize) events.
   */
  async initialize(canvas: HTMLCanvasElement): Promise<void> {
    await super.initialize(canvas)
    this._configurator.onStartOptionsChange = this.optionsChangeHandler
    this._configurator.randomizeForces = this.randomizeForces
    this._accumulator.initialize(this._size)

    this.createRenderPipeline()
    this.createComputePipeline()
    this.createStarsBuffers()

    this._viewport.onAnyDrag.add((e) => {
      // Drag delta normalized by canvas size; x is flipped and scaled by the
      // aspect ratio so the scene follows the pointer.
      this._cameraPosition.selfAdd(
        e.change
          .divide(this._size)
          .selfMultiply(new Vector2(-1 * this._aspectRatio, 1))
      )
    })
    this._viewport.onResize.add((size) => {
      this._accumulator.resize(size)
    })
  }

  /** Releases the configurator, both uniform structs and all attribute buffers. */
  protected disposeProgram() {
    this._configurator.dispose()
    this._renderConfigStruct.dispose()
    // The compute uniforms are a UniformStruct just like the render ones and
    // must be released too; they were previously never disposed.
    this._computeConfig.dispose()
    this._dynamicAttributes.forEach((a) => a.dispose())
    this._staticAttributes.dispose()
    this._colorInteractionForces.dispose()
  }

  /**
   * Builds the instanced render pipeline.
   *
   * - Vertex buffer slot 0 carries the static color-index buffer.
   * - Slot 1 carries the dynamic position/velocity buffer.
   *
   * Both slots step per instance, and each instance is drawn as a triangle
   * strip (see `draw`, which issues 4 vertices per instance).
   */
  private createRenderPipeline() {
    this._renderUniformGroup.initialize(this._device)

    // Vertex and fragment entry points come from the same WGSL source, so it
    // is compiled once and the module is shared between both stages.
    const shaderModule = this._device.createShaderModule({
      code: particlesRenderWGSL,
    })

    const pipelineDescriptor: GPURenderPipelineDescriptor = {
      label: 'render pipeline',
      layout: this._device.createPipelineLayout({
        bindGroupLayouts: [this._renderUniformGroup.getLayout(this._device)],
      }),
      vertex: {
        module: shaderModule,
        entryPoint: 'vertex',
        buffers: [
          // TODO: a class to make this incrementing shader locations
          {
            stepMode: 'instance',
            attributes: this._staticAttributes.getDescriptor(),
            arrayStride: this._staticAttributes.elementSize,
          } as GPUVertexBufferLayout,
          {
            stepMode: 'instance',
            // NOTE(review): the argument presumably offsets shader locations
            // past the single static attribute — confirm in Attributes.
            attributes: this._dynamicAttributes[0].getDescriptor(1),
            arrayStride: this._dynamicAttributes[0].elementSize,
          } as GPUVertexBufferLayout,
        ],
      },
      fragment: {
        module: shaderModule,
        entryPoint: 'fragment',
        targets: [
          {
            format: this._gpu.presentationFormat,
            // Standard alpha blending (src-alpha / one-minus-src-alpha).
            blend: {
              color: {
                operation: 'add',
                srcFactor: 'src-alpha',
                dstFactor: 'one-minus-src-alpha',
              } as GPUBlendComponent,
              alpha: {
                srcFactor: 'one',
                dstFactor: 'one-minus-src-alpha',
                operation: 'add',
              } as GPUBlendComponent,
            } as GPUBlendState,
          } as GPUColorTargetState,
        ],
      },
      primitive: {
        topology: 'triangle-strip',
      },
    }
    this._renderPipeline = this._device.createRenderPipeline(pipelineDescriptor)
  }

  /**
   * Builds the compute pipeline.
   *
   * The pipeline layout is taken from bind group 0, but bind group 1 is also
   * bound against this pipeline. Both groups are built from the same entry
   * kinds, so their layouts are expected to match.
   */
  private createComputePipeline() {
    const pipelineDescriptor: GPUComputePipelineDescriptor = {
      label: 'compute pipeline',
      layout: this._device.createPipelineLayout({
        label: 'computeLayout',
        bindGroupLayouts: [this._computeBindGroups[0].getLayout(this._device)],
      }),
      compute: {
        module: this._device.createShaderModule({
          code: particlesComputeWGSL,
        }),
        entryPoint: 'main',
      },
    }
    this._computePipeline =
      this._device.createComputePipeline(pipelineDescriptor)
  }

  /**
   * (Re)creates all particle-related GPU buffers.
   *
   * - Static and dynamic buffers get one element per particle
   *   (`options.particles`).
   * - The force table is regenerated with Colors.MAX ** 2 entries.
   * - The compute bind groups are re-initialized against the new buffers.
   */
  private createStarsBuffers() {
    const { particles } = this._configurator.options

    this._staticAttributes.createBuffer(
      particles,
      this._device,
      this._staticAttributesMapper,
      GPUBufferUsage.VERTEX | GPUBufferUsage.STORAGE
    )
    this._dynamicAttributes.forEach((a) =>
      a.createBuffer(
        particles,
        this._device,
        this._dynamicAttributesMapper,
        // COPY_SRC allows copyParticlePositionToCameraPosition to read it.
        GPUBufferUsage.VERTEX | GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
      )
    )
    this._colorInteractionForces.createBuffer(
      Colors.MAX ** 2,
      this._device,
      this._colorInteractionForcesMapper,
      // COPY_DST allows randomizeForces to rewrite the contents in place.
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
    )

    this._computeBindGroups.forEach((g) => g.initialize(this._device))
  }

  // Configurator callback: start options changed, so rebuild the buffers.
  private optionsChangeHandler = () => {
    this.createStarsBuffers()
  }

  // Configurator callback: overwrite the force table with fresh random values.
  private randomizeForces = () => {
    this._colorInteractionForces.writeContent(
      this._device,
      this._colorInteractionForcesMapper
    )
  }

  /** Per-frame entry: encodes simulation and rendering into one submission. */
  protected update() {
    const commandEncoder = this._device.createCommandEncoder()
    this.updateCompute(commandEncoder)
    this.draw(commandEncoder)
    this._device.queue.submit([commandEncoder.finish()])
  }

  /**
   * Applies stats/speed options and pushes the uniforms. Then encodes
   * `updatesPerFrame` compute passes, toggling the ping-pong bind group
   * before each one. When following a particle, it finally encodes a copy of
   * that particle's position into the camera uniform.
   */
  private updateCompute(commandEncoder: GPUCommandEncoder) {
    const { speed, particles, stats, updatesPerFrame, followAParticle } =
      this._configurator.options
    this.showStats(stats)
    this._loop.speed = speed
    this.updateUniforms()

    for (let i = 0; i < updatesPerFrame; i++) {
      this._currentBindGroup = (this._currentBindGroup + 1) % 2
      const passEncoder = commandEncoder.beginComputePass()
      passEncoder.setPipeline(this._computePipeline)
      this._computeBindGroups[this._currentBindGroup].bindTo(passEncoder)
      // NOTE(review): 64 presumably matches @workgroup_size in
      // particles.compute.wgsl — keep the two in sync.
      passEncoder.dispatchWorkgroups(Math.ceil(particles / 64))
      passEncoder.end()
    }

    // Copy particle position to camera position
    if (followAParticle) {
      this.copyParticlePositionToCameraPosition(commandEncoder)
    }
  }

  /**
   * Draws every particle as one instance. With `autoClear` the target is the
   * current canvas texture; otherwise the particles are drawn into the
   * accumulator's target, and `_accumulator.render` is encoded afterwards.
   */
  private draw(commandEncoder: GPUCommandEncoder) {
    const { particles, autoClear } = this._configurator.options

    const renderPassDescriptor: GPURenderPassDescriptor = autoClear
      ? this.getRenderToScreenPassDescriptor()
      : this._accumulator.getDescriptor()

    const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor)
    passEncoder.setPipeline(this._renderPipeline)
    this._renderUniformGroup.bindTo(passEncoder)
    this.getLatestDynamicAttributes().setVertexBuffer(passEncoder, 1)
    this._staticAttributes.setVertexBuffer(passEncoder, 0)
    passEncoder.draw(4, particles, 0, 0)
    passEncoder.end()

    if (!autoClear) {
      this._accumulator.render(commandEncoder)
    }
  }

  /** Copies the current options and viewport state into both uniform structs. */
  private updateUniforms() {
    const { options } = this._configurator

    this._renderConfigStruct.aspectRatio.setValue(this._aspectRatio)
    this._renderConfigStruct.particleSize.setValue(options.particleSize)
    // While following a particle, the camera uniform is written on the GPU by
    // copyParticlePositionToCameraPosition instead.
    if (!options.followAParticle) {
      this._renderConfigStruct.cameraPosition.setValue(this._cameraPosition)
    }
    // NOTE(review): colors are read from options[0..Colors.MAX - 1]; confirm
    // the Configurator exposes them under numeric keys.
    this._renderConfigStruct.colors.forEach((field, i) => {
      field.setValue(Color.fromHexString(options[i]))
    })

    // Each compute sub-step advances by an equal share of the frame time.
    this._computeConfig.frameTime.setValue(
      this._loop.elapsedSeconds / options.updatesPerFrame
    )
    // Hard-coded scale factors applied to the UI-facing force values.
    this._computeConfig.force.setValue(options.force * 0.03)
    this._computeConfig.repulsionForce.setValue(options.repulsionForce * 5)
    this._computeConfig.dampening.setValue(options.dampening)
    this._computeConfig.rmin.setValue(options.rmin)
    this._computeConfig.rmax.setValue(options.rmax)
    this._computeConfig.aspectRatio.setValue(this._aspectRatio)
  }

  /**
   * Fills a dynamic buffer with one record per particle: [posX, posY, velX, velY].
   *
   * - Position x is uniform in [-0.5, 0.5) * aspectRatio; y is in [-0.5, 0.5).
   * - Velocity has a random direction and a speed in [0, 0.01).
   */
  private _dynamicAttributesMapper = (buffer: ArrayBuffer): void => {
    const array = new Float32Array(buffer)
    const { options } = this._configurator
    let index = 0

    for (let particle = 0; particle < options.particles; particle++) {
      const positionX = (Math.random() - 0.5) * this._aspectRatio
      const positionY = Math.random() - 0.5
      const velocityAngle = Math.random() * Math.PI * 2
      const speed = Math.random() * 0.01
      const velocityX = Math.cos(velocityAngle) * speed
      const velocityY = Math.sin(velocityAngle) * speed

      // position x, y
      array[index++] = positionX
      array[index++] = positionY
      // velocity x, y
      array[index++] = velocityX
      array[index++] = velocityY
    }
  }

  /** Assigns each particle a random color index in [0, options.colorQuantity). */
  private _staticAttributesMapper = (buffer: ArrayBuffer): void => {
    const array = new Uint32Array(buffer)
    const { options } = this._configurator
    let index = 0
    const colors = options.colorQuantity

    for (let particle = 0; particle < options.particles; particle++) {
      // color index
      array[index++] = Math.floor(Math.random() * colors)
    }
  }

  /** Fills Colors.MAX ** 2 interaction forces with random values in [-1, 1). */
  private _colorInteractionForcesMapper = (buffer: ArrayBuffer): void => {
    const array = new Float32Array(buffer)
    const length = Colors.MAX ** 2
    for (let i = 0; i < length; i++) {
      array[i] = (Math.random() - 0.5) * 2
    }
  }

  /** Returns the dynamic attribute set written by the most recent compute pass. */
  private getLatestDynamicAttributes() {
    return this._dynamicAttributes[(this._currentBindGroup + 1) % 2]
  }

  /**
   * Encodes a GPU-side copy of particle 0's position (the leading float32x2
   * of its record) into the camera-position field of the render uniforms.
   * This avoids reading the particle data back to the CPU.
   */
  private copyParticlePositionToCameraPosition(
    commandEncoder: GPUCommandEncoder
  ) {
    const source = this.getLatestDynamicAttributes()
    const particleIndex = 0
    // Only the position (two float32 values) is copied.
    const byteLength = Float32Array.BYTES_PER_ELEMENT * 2
    // Use the layout's own stride (also used as the vertex arrayStride)
    // instead of re-deriving it from the field list.
    const byteStartOnSourceBuffer = particleIndex * source.elementSize
    commandEncoder.copyBufferToBuffer(
      source.buffer,
      byteStartOnSourceBuffer,
      this._renderConfigStruct.cameraPosition.buffer,
      this._renderConfigStruct.cameraPosition.byteStart,
      byteLength
    )
  }

  /** Pass descriptor that clears the current canvas texture before drawing. */
  private getRenderToScreenPassDescriptor(): GPURenderPassDescriptor {
    return {
      label: 'Render to screen pass',
      colorAttachments: [
        {
          view: this._context.getCurrentTexture().createView(),
          clearValue: new Color(0, 0, 0, 255).toGPUColor(),
          loadOp: 'clear',
          storeOp: 'store',
        } as GPURenderPassColorAttachment,
      ],
    }
  }
}
