import { isNil } from 'lodash'
import { UniformField } from './UniformField'
import {
  UniformFloatField,
  UniformFloatRGBColorField,
  UniformMatrix4Field,
  UniformValueField,
  UniformVector2Field,
  UniformVector3Field,
} from './UniformValueField'
import {
  UniformArrayField,
  UniformFieldBuilder,
  ArrayItem,
} from './UniformArrayField'
import { Bindable } from '../Bindable'
import { getByteLengthToAlign } from './alignment'

/**
 * Base class for a uniform-buffer struct built from scalar, vector, matrix
 * and array fields.
 *
 * Subclasses declare their layout by calling the protected `add*Field`
 * helpers; fields are laid out in the order they were added.
 *
 * A struct can be used in two ways:
 * - As a top-level binding (`Bindable`). `getOrCreateBuffer` lays it out
 *   from byte 0, creates and owns the `GPUBuffer`, and hands that buffer to
 *   every field.
 * - As a `UniformField` inside an enclosing layout. Here the enclosing layout
 *   is expected to call `setByteRange` / `setBuffer`.
 */
export abstract class UniformStruct implements UniformField, Bindable {
  private _byteStart?: number
  private _byteLength?: number
  private readonly _fields = new Array<UniformField>()
  private _buffer?: GPUBuffer
  private _resource?: GPUBindingResource
  protected label?: string

  /**
   * Byte offset of the first field within the backing buffer.
   * @throws Error if the byte range has not been assigned yet.
   */
  get byteStart(): number {
    if (isNil(this._byteStart)) {
      throw new Error('Byte range not set.')
    }
    return this._byteStart
  }

  /**
   * Number of bytes from `byteStart` to the end of the last field. For a
   * top-level struct this also includes the trailing padding added by
   * `initializeByteRange`.
   * @throws Error if the byte range has not been assigned yet.
   */
  get byteLength(): number {
    if (isNil(this._byteLength)) {
      throw new Error('Byte range not set.')
    }
    return this._byteLength
  }

  /**
   * Alignment of the struct: the largest `minAlignment` of its fields, or 0
   * when no fields have been added.
   */
  get minAlignment(): number {
    return this._fields.reduce(
      (maxAlignment, field) => Math.max(field.minAlignment, maxAlignment),
      0
    )
  }

  /** Partial bind-group-layout entry describing a uniform buffer binding. */
  getLayout(): Partial<GPUBindGroupLayoutEntry> {
    return {
      buffer: {
        type: 'uniform',
      },
    }
  }

  /**
   * Returns the cached binding resource, creating the buffer and resource on
   * first use.
   */
  getOrCreateResource(device: GPUDevice): GPUBindingResource {
    return this._resource ?? this.createResource(device)
  }

  /**
   * Returns the buffer owned by this struct.
   *
   * On first use this lays out all fields starting at byte 0 and creates a
   * UNIFORM | COPY_DST buffer sized to the struct. It then passes that buffer
   * to every field.
   */
  getOrCreateBuffer(device: GPUDevice): GPUBuffer {
    if (this._buffer !== undefined) {
      return this._buffer
    }
    this.initializeByteRange()
    // Keep a local reference: TS does not carry narrowing of `this._buffer`
    // into callbacks, and the local makes the value passed to fields explicit.
    const buffer = device.createBuffer({
      label: this.label,
      size: this.byteLength,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    })
    this._buffer = buffer
    for (const field of this._fields) {
      field.setBuffer(buffer, device)
    }
    return buffer
  }

  /**
   * Lays out the fields one after another, starting at `startByteIndex`.
   *
   * Each field decides its own offset from the previous field's end; the
   * struct's range spans from the first field's start to the last field's
   * end.
   *
   * NOTE(review): the struct start is not rounded up to `minAlignment` here,
   * and neither is the struct size. WGSL requires both for structs nested
   * inside other structs or arrays; only `initializeByteRange` pads the
   * top-level size. Confirm this matches how nested structs are declared in
   * shaders.
   *
   * @throws Error if the struct has no fields.
   */
  setByteRange(startByteIndex: number, bufferAlignment: number) {
    const first = this._fields.at(0)
    const last = this._fields.at(-1)
    if (first === undefined || last === undefined) {
      throw new Error('Cannot lay out a uniform struct with no fields.')
    }
    let nextByte = startByteIndex
    for (const field of this._fields) {
      field.setByteRange(nextByte, bufferAlignment)
      nextByte = field.byteStart + field.byteLength
    }
    this._byteStart = first.byteStart
    this._byteLength = last.byteStart + last.byteLength - this._byteStart
  }

  /** Forwards the backing buffer to every field. */
  setBuffer(buffer: GPUBuffer, device: GPUDevice): void {
    for (const field of this._fields) {
      field.setBuffer(buffer, device)
    }
  }

  /**
   * Destroys the buffer this struct created, if any. It also drops the cached
   * buffer and binding resource. As a result, a later `getOrCreateBuffer` /
   * `getOrCreateResource` creates fresh ones instead of returning a destroyed
   * buffer.
   */
  dispose() {
    this._buffer?.destroy()
    this._buffer = undefined
    this._resource = undefined
  }

  /** Appends an f32 field. */
  protected addFloatField(): UniformFloatField {
    return this.addValueField(new UniformFloatField())
  }
  /** Appends a 2-component vector field. */
  protected addVector2Field(): UniformVector2Field {
    return this.addValueField(new UniformVector2Field())
  }
  /** Appends a 3-component vector field. */
  protected addVector3Field(): UniformVector3Field {
    return this.addValueField(new UniformVector3Field())
  }
  /** Appends a float RGB color field. */
  protected addFloatRGBColorField(): UniformFloatRGBColorField {
    return this.addValueField(new UniformFloatRGBColorField())
  }
  /** Appends a 4x4 matrix field. */
  protected addMatrix4Field(): UniformMatrix4Field {
    return this.addValueField(new UniformMatrix4Field())
  }

  /** Appends `field` to the layout and returns it for chaining/assignment. */
  protected addValueField<T>(
    field: UniformValueField<T>
  ): UniformValueField<T> {
    this._fields.push(field)
    return field
  }

  /**
   * Appends an array field of `length` items, each created by `fieldBuilder`.
   */
  protected addArrayField<T extends ArrayItem>(
    length: number,
    fieldBuilder: UniformFieldBuilder<T>
  ): UniformArrayField<T> {
    const field = new UniformArrayField<T>(length, fieldBuilder)
    this._fields.push(field)
    return field
  }

  /**
   * Lays the struct out from byte 0, aligning fields to the struct alignment.
   * It then extends the total length with `getByteLengthToAlign` padding,
   * presumably up to the next multiple of that alignment.
   */
  protected initializeByteRange() {
    const alignment = this.minAlignment
    this.setByteRange(0, alignment)
    const unpadded = this.byteLength
    this._byteLength = unpadded + getByteLengthToAlign(unpadded, alignment)
  }

  /** Creates the buffer (if needed) and caches a whole-buffer binding. */
  private createResource(device: GPUDevice): GPUBindingResource {
    const buffer = this.getOrCreateBuffer(device)
    return (this._resource = {
      buffer,
      size: buffer.size,
    })
  }
}
