import { isNil } from 'lodash'
import { UniformField } from './UniformField'
import { UniformStruct } from './UniformStruct'
import { getByteLengthToAlign } from './alignment'

/** Zero-argument factory returning a `T`; `UniformArrayField` calls it once per element. */
export type UniformFieldBuilder<T> = () => T
/** Element kinds a `UniformArrayField` may hold: a single uniform field or a uniform struct. */
export type ArrayItem = UniformField | UniformStruct

/**
 * Fixed-length array of uniform fields or structs that is itself a
 * `UniformField`.
 *
 * Elements are created eagerly by the builder passed to the constructor.
 * Their byte ranges are assigned one after another by `setByteRange`, with
 * each element's start padded using `getByteLengthToAlign`.
 */
export class UniformArrayField<T extends ArrayItem> implements UniformField {
  private _byteStart?: number
  private _byteLength?: number
  /** Number of elements requested at construction time. */
  readonly length: number
  /** The elements, in layout order. */
  readonly values: T[]

  /**
   * @param length Number of elements to create; must be a positive integer.
   * @param buildField Factory invoked once per element, in index order.
   * @throws Error when `length` is not a positive integer.
   */
  constructor(length: number, buildField: UniformFieldBuilder<T>) {
    // Written as a negated `>` so that NaN is rejected here too, rather than
    // surfacing later as an opaque "Invalid array length" RangeError.
    if (!(length > 0)) {
      throw new Error('Length must be a positive value.')
    }
    if (!Number.isInteger(length)) {
      throw new Error('Length must be an integer.')
    }
    this.length = length
    this.values = Array.from({ length }, () => buildField())
  }

  /**
   * Byte offset of the first element.
   * @throws Error when `setByteRange` has not been called yet.
   */
  get byteStart(): number {
    if (isNil(this._byteStart)) {
      throw new Error('Byte range not set.')
    }
    return this._byteStart
  }

  /**
   * Number of bytes from the start of the first element to the end of the
   * last one, including any padding inserted between elements.
   * @throws Error when `setByteRange` has not been called yet.
   */
  get byteLength(): number {
    if (isNil(this._byteLength)) {
      throw new Error('Byte range not set.')
    }
    return this._byteLength
  }

  /** Largest `minAlignment` among the elements (0 when there are none). */
  get minAlignment(): number {
    return this.values.reduce(
      (maxAlignment, field) => Math.max(field.minAlignment, maxAlignment),
      0
    )
  }

  /**
   * Assigns consecutive byte ranges to every element and records the overall
   * range covered by the array.
   *
   * @param startByteIndex Byte offset at which layout starts; the first
   *   element is placed at this offset plus whatever padding
   *   `getByteLengthToAlign` reports for it.
   * @param bufferAlignment Alignment forwarded to `getByteLengthToAlign` and
   *   to every element's own `setByteRange`.
   * @throws Error when `values` has been emptied after construction.
   */
  setByteRange(startByteIndex: number, bufferAlignment: number): void {
    // Use a local cursor instead of reassigning the parameter.
    let cursor = startByteIndex
    for (const value of this.values) {
      cursor += getByteLengthToAlign(cursor, bufferAlignment)
      value.setByteRange(cursor, bufferAlignment)
      cursor += value.byteLength
    }

    const first = this.values[0]
    const last = this.values[this.values.length - 1]
    // The constructor guarantees at least one element, but `values` is a
    // public mutable array, so guard instead of dereferencing blindly.
    if (first === undefined || last === undefined) {
      throw new Error('Cannot set the byte range of an empty array field.')
    }
    this._byteStart = first.byteStart
    this._byteLength = last.byteStart + last.byteLength - first.byteStart
  }

  /** Forwards `buffer` and `device` to every element's `setBuffer`. */
  setBuffer(buffer: GPUBuffer, device: GPUDevice): void {
    this.values.forEach(v => v.setBuffer(buffer, device))
  }

  /**
   * Returns the element at `index` via plain index access.
   *
   * NOTE: no bounds checking is performed, so an out-of-range (or negative)
   * index yields `undefined` at runtime despite the declared return type.
   */
  at(index: number): T {
    return this.values[index]
  }

  /** Invokes `action` for each element, with `Array.prototype.forEach` arguments. */
  forEach(action: (item: T, index: number, array: T[]) => void) {
    this.values.forEach(action)
  }
}
