import { Bindable } from './Bindable'
import { IBindGroup } from './IBindGroup'

/**
 * Groups a list of {@link Bindable} resources into a single WebGPU bind group.
 *
 * Item `i` of `items` is exposed at binding slot `i`, and every slot shares
 * the same shader-stage visibility. When an item signals a change through its
 * `onUpdated` hook, the GPU bind group is marked stale and rebuilt lazily on
 * the next {@link BindGroup.bindTo} call.
 */
export class BindGroup implements IBindGroup {
  private readonly _items: Bindable[]
  private readonly _visibility: GPUShaderStageFlags
  private _gpuBindGroup?: GPUBindGroup
  // Set by `initialize`; remembered so `bindTo` can rebuild after an item update.
  private _device?: GPUDevice
  private _needsUpdate = false

  /**
   * @param items - Resources to bind; item `i` is assigned binding index `i`.
   * @param visibility - Shader stages that can access every binding
   *   (defaults to vertex | fragment).
   */
  constructor(
    items: Bindable[],
    visibility: GPUShaderStageFlags = GPUShaderStage.VERTEX |
      GPUShaderStage.FRAGMENT
  ) {
    this._items = items
    this._visibility = visibility

    // Mark the GPU bind group stale whenever any item reports a change.
    // NOTE(review): these listeners are never removed, so items keep a
    // reference to this group even after `dispose()` — consider
    // unsubscribing if the `onUpdated` API supports removal.
    this._items.forEach((i) =>
      i.onUpdated?.add(() => (this._needsUpdate = true))
    )
  }

  /**
   * Creates a new bind group layout describing this group's items.
   *
   * @param device - Device used to create the layout.
   * @param label - Optional debug label attached to the layout.
   * @returns A freshly created layout; it is not cached by this class.
   */
  getLayout(device: GPUDevice, label?: string): GPUBindGroupLayout {
    return device.createBindGroupLayout(
      this.getLayoutDescriptor(this._visibility, label)
    )
  }

  /**
   * Creates (or recreates) the GPU bind group on `device` and clears any
   * pending update, so a following `bindTo` does not rebuild it again.
   *
   * @param device - Device used for the layout, resources and bind group;
   *   remembered for later rebuilds triggered by item updates.
   */
  initialize(device: GPUDevice): void {
    this._device = device
    this._gpuBindGroup = device.createBindGroup({
      layout: this.getLayout(device),
      entries: this.getLayoutEntries(device),
    })
    // Cleared only after a successful build, so a failed build is retried.
    this._needsUpdate = false
  }

  /**
   * Sets this bind group on `passEncoder` at group `index`, first rebuilding
   * it if an item has reported an update since the last build.
   *
   * @param passEncoder - Render or compute pass to bind into.
   * @param index - Bind group index (defaults to 0).
   * @throws Error if the group has not been initialized (or was disposed).
   */
  bindTo(
    passEncoder: GPURenderPassEncoder | GPUComputePassEncoder,
    index = 0
  ): void {
    if (this._needsUpdate && this._device) {
      this.initialize(this._device)
    }
    if (!this._gpuBindGroup) {
      // Fail loudly here rather than passing `undefined` to the encoder,
      // which would only surface later as a confusing draw/dispatch error.
      throw new Error('BindGroup must be initialized before bindTo()')
    }
    passEncoder.setBindGroup(index, this._gpuBindGroup)
  }

  /**
   * Disposes every item that exposes a `dispose` method and drops this
   * group's references to the GPU bind group and device. Call `initialize`
   * again before reusing the group.
   */
  dispose(): void {
    this._items.forEach((i) => i.dispose?.())
    this._gpuBindGroup = undefined
    this._device = undefined
  }

  /**
   * Builds the layout descriptor: one entry per item, with binding index equal
   * to the item's position and the item's own layout fields spread in.
   */
  private getLayoutDescriptor(
    visibility: GPUShaderStageFlags,
    label?: string
  ): GPUBindGroupLayoutDescriptor {
    return {
      label,
      entries: this._items.map((item, index) => ({
        binding: index,
        visibility,
        ...item.getLayout(),
      })),
    } as GPUBindGroupLayoutDescriptor
  }

  /**
   * Builds the bind group entries, pairing each binding index with the
   * resource returned by the corresponding item's `getOrCreateResource`.
   */
  private getLayoutEntries(device: GPUDevice): GPUBindGroupEntry[] {
    return this._items.map(
      (item, index) =>
        ({
          binding: index,
          resource: item.getOrCreateResource(device),
        } as GPUBindGroupEntry)
    )
  }
}
