import { isNil } from 'lodash'

/**
 * Byte size of each supported `GPUVertexFormat`.
 *
 * Formats missing from this table are not supported by `Attributes`.
 * Each entry uses the typed array whose element type matches the format's
 * component type, multiplied by the component count.
 */
const sizes: Partial<Record<GPUVertexFormat, number>> = {
  float32: Float32Array.BYTES_PER_ELEMENT,
  float32x2: 2 * Float32Array.BYTES_PER_ELEMENT,
  float32x3: 3 * Float32Array.BYTES_PER_ELEMENT,
  float32x4: 4 * Float32Array.BYTES_PER_ELEMENT,
  uint32: Uint32Array.BYTES_PER_ELEMENT,
  uint32x2: 2 * Uint32Array.BYTES_PER_ELEMENT,
  uint32x3: 3 * Uint32Array.BYTES_PER_ELEMENT,
  uint32x4: 4 * Uint32Array.BYTES_PER_ELEMENT,
  sint32: Int32Array.BYTES_PER_ELEMENT,
  sint32x2: 2 * Int32Array.BYTES_PER_ELEMENT,
  sint32x3: 3 * Int32Array.BYTES_PER_ELEMENT,
  sint32x4: 4 * Int32Array.BYTES_PER_ELEMENT,
}

/**
 * Callback that writes element data into a raw `ArrayBuffer`: either the
 * mapped range of a freshly created GPU buffer or a CPU-side copy that is
 * uploaded afterwards.
 */
export type Mapper = (buffer: ArrayBuffer) => void

/**
 * Describes the vertex attributes stored in one interleaved vertex buffer and
 * owns the `GPUBuffer` that holds them.
 *
 * Only formats listed in `sizes` are supported; any other format throws.
 */
export class Attributes {
  private readonly _attributes: GPUVertexAttribute[]
  private readonly _elementSize: number
  // TODO: private
  _buffer?: GPUBuffer

  /** Attribute descriptors with shader locations starting at 0. */
  get descriptor(): GPUVertexAttribute[] {
    return this.getDescriptor()
  }

  /**
   * Byte size of one element: the end of the furthest attribute,
   * i.e. `max(offset + size)` over all attributes.
   */
  get elementSize(): number {
    return this._elementSize
  }

  /** The current buffer, or `undefined` before `createBuffer()` / after `dispose()`. */
  get buffer(): GPUBuffer | undefined {
    return this._buffer
  }

  /**
   * @param attributes vertex attributes laid out within a single element.
   * @throws Error if any attribute uses a format missing from `sizes`.
   */
  constructor(attributes: GPUVertexAttribute[]) {
    this._attributes = attributes
    // Use the furthest attribute end instead of the sum of sizes so layouts
    // with padding between attributes still get a large enough element size;
    // for contiguous layouts both are identical.
    this._elementSize = attributes.reduce(
      (end, { format, offset }) =>
        Math.max(end, offset + Attributes.formatSize(format)),
      0
    )
  }

  /**
   * (Re)creates the backing buffer with room for `elements` elements,
   * destroying any previously created buffer first.
   *
   * @param elements number of elements to allocate.
   * @param device device used to create the buffer.
   * @param mapper optional filler; when given, the buffer is created mapped,
   *   passed to `mapper`, then unmapped (even if `mapper` throws).
   * @param usage buffer usage flags, `GPUBufferUsage.VERTEX` by default.
   * @returns the newly created buffer.
   */
  createBuffer(
    elements: number,
    device: GPUDevice,
    mapper?: Mapper,
    usage: GPUBufferUsageFlags = GPUBufferUsage.VERTEX
  ): GPUBuffer {
    this._buffer?.destroy()
    // Don't keep a reference to a destroyed buffer if creation below throws.
    this._buffer = undefined
    const buffer = device.createBuffer({
      size: elements * this._elementSize,
      usage,
      mappedAtCreation: !isNil(mapper),
    })
    this._buffer = buffer
    if (!isNil(mapper)) {
      try {
        mapper(buffer.getMappedRange())
      } finally {
        // A buffer left mapped cannot be used by the GPU.
        buffer.unmap()
      }
    }
    return buffer
  }

  /**
   * Fills a CPU-side copy of the whole buffer with `mapper` and uploads it
   * with `device.queue.writeBuffer` at offset 0.
   *
   * @throws Error if no buffer has been created.
   */
  writeContent(device: GPUDevice, mapper: Mapper) {
    const buffer = this.requireBuffer('writeContent')
    const data = new ArrayBuffer(buffer.size)
    mapper(data)
    device.queue.writeBuffer(buffer, 0, data)
  }

  /**
   * Binds the buffer to vertex buffer `slot` of `passEncoder`.
   *
   * @throws Error if no buffer has been created.
   */
  setVertexBuffer(passEncoder: GPURenderPassEncoder, slot = 0) {
    passEncoder.setVertexBuffer(slot, this.requireBuffer('setVertexBuffer'))
  }

  /**
   * Returns copies of the attribute descriptors with every `shaderLocation`
   * shifted by `startShaderLocation`.
   */
  getDescriptor(startShaderLocation = 0): GPUVertexAttribute[] {
    return this._attributes.map(({ shaderLocation, ...a }) => ({
      ...a,
      shaderLocation: shaderLocation + startShaderLocation,
    }))
  }

  /** Destroys the buffer (if any) and forgets it. Safe to call repeatedly. */
  dispose() {
    this._buffer?.destroy()
    this._buffer = undefined
  }

  /**
   * Builds tightly packed attributes: attribute `i` gets shader location `i`
   * and the offset right after attribute `i - 1`.
   *
   * @throws Error if any format is missing from `sizes`.
   */
  static fromTypes(types: GPUVertexFormat[]): Attributes {
    let offset = 0
    const attributes = types.map((format, index) => {
      const attribute: GPUVertexAttribute = {
        shaderLocation: index,
        offset,
        format,
      }
      offset += Attributes.formatSize(format)
      return attribute
    })
    return new Attributes(attributes)
  }

  /** Byte size of `format`; throws for formats missing from `sizes`. */
  private static formatSize(format: GPUVertexFormat): number {
    const size = sizes[format]
    if (!size) {
      throw new Error(`type ${format} not implemented`)
    }
    return size
  }

  /** Returns the current buffer or throws a descriptive error naming `action`. */
  private requireBuffer(action: string): GPUBuffer {
    if (!this._buffer) {
      throw new Error(
        `Attributes.${action}: no buffer, call createBuffer() first`
      )
    }
    return this._buffer
  }
}
