import { Bindable } from './Bindable'
import { BindGroup } from './BindGroup'
import { IBindGroup } from './IBindGroup'
import { TexturePair } from './textures/TexturePair'

/**
 * An {@link IBindGroup} whose entries may include {@link TexturePair}s.
 *
 * Each underlying `BindGroup` binds fixed resources. To cover every
 * combination of "source or destination" across the texture pairs, one
 * `BindGroup` is created up front per combination: 2^k groups for k pairs.
 * At bind time, the group matching each pair's current `sourceIndex` is
 * selected.
 *
 * Variant numbering: bit `j` of a variant index is set when the j-th texture
 * pair (in `items` order) binds its `destination` rather than its `source`.
 */
export class BindGroupWithTexturePairs implements IBindGroup {
  private readonly _bindGroups: BindGroup[]
  private readonly _texturePairs: TexturePair[]

  /**
   * @param items Bind group entries in binding order. Each `TexturePair`
   *   occupies a single slot. That slot is filled with either the pair's
   *   `source` or its `destination`, depending on the variant.
   * @param visibility Shader stages the entries are visible to. It is
   *   forwarded to every underlying `BindGroup`.
   */
  constructor(
    items: Array<Bindable | TexturePair>,
    visibility: GPUShaderStageFlags = GPUShaderStage.VERTEX |
      GPUShaderStage.FRAGMENT
  ) {
    this._texturePairs = items.filter(
      (item): item is TexturePair => item instanceof TexturePair
    )

    // NOTE(review): `source` / `destination` are read only once, here. Variant
    // selection in `bindTo` treats `sourceIndex === 0` as "bind `source`".
    // This presumably requires every pair to have `sourceIndex === 0` at
    // construction time. TODO confirm against TexturePair.
    const variantCount = 2 ** this._texturePairs.length
    this._bindGroups = Array.from({ length: variantCount }, (_, variant) => {
      let pairOrdinal = 0
      const resolved = items.map((item): Bindable => {
        if (!(item instanceof TexturePair)) return item
        // Bit `pairOrdinal` of the variant selects destination (1) or source (0).
        const useDestination = ((variant >> pairOrdinal) & 1) === 1
        pairOrdinal += 1
        return useDestination ? item.destination : item.source
      })
      return new BindGroup(resolved, visibility)
    })
  }

  /**
   * Returns the layout of the first variant.
   *
   * Variants differ only in which texture of each pair they bind. This
   * assumes that the `source` and `destination` of a pair produce identical
   * layout entries. TODO confirm.
   */
  getLayout(device: GPUDevice, label?: string): GPUBindGroupLayout {
    // At least one variant always exists (2^0 === 1 when there are no pairs).
    return this._bindGroups[0].getLayout(device, label)
  }

  /** Initializes every variant on `device`. */
  initialize(device: GPUDevice): void {
    this._bindGroups.forEach((bindGroup) => bindGroup.initialize(device))
  }

  /**
   * Binds the variant that matches every texture pair's current `sourceIndex`.
   *
   * @param passEncoder Render or compute pass to bind into.
   * @param index Bind group slot, passed through to `BindGroup.bindTo`.
   * @throws Error if no variant exists for the current `sourceIndex` values.
   *   Presumably this only happens when a `sourceIndex` is outside {0, 1}.
   */
  bindTo(
    passEncoder: GPURenderPassEncoder | GPUComputePassEncoder,
    index = 0
  ): void {
    // Pair j contributes sourceIndex * 2^j, which mirrors the bit layout used
    // in the constructor. Linear weights such as (j + 1) would make distinct
    // orientations collide from the third pair onward.
    const variant = this._texturePairs.reduce(
      (acc, pair, pairOrdinal) => acc + pair.sourceIndex * 2 ** pairOrdinal,
      0
    )
    const bindGroup = this._bindGroups[variant]
    if (bindGroup === undefined) {
      throw new Error(
        `BindGroupWithTexturePairs: no bind group for variant ${variant}`
      )
    }
    bindGroup.bindTo(passEncoder, index)
  }

  /** Disposes every variant. */
  dispose() {
    this._bindGroups.forEach((bindGroup) => bindGroup.dispose())
  }
}
