import { IBindGroup } from './IBindGroup'

/**
 * A compute shader stage whose `module` is supplied as a shader-module
 * *descriptor* instead of an already-created `GPUShaderModule`; the module
 * is created from it via `device.createShaderModule` when the pipeline is
 * built. `entryPoint` is required here.
 */
export type ComputeStage = Omit<
  GPUProgrammableStage,
  'module' | 'entryPoint'
> & {
  /** Descriptor (e.g. the WGSL `code`) used to create the shader module. */
  module: GPUShaderModuleDescriptor
  /** Name of the entry-point function inside the shader module. */
  entryPoint: string
}

/**
 * Descriptor accepted by `ComputePipeline`: a `GPUComputePipelineDescriptor`
 * without `layout` (the pipeline layout is built from the bind groups'
 * layouts instead) and with `compute` replaced by a `ComputeStage`, whose
 * shader module is given as a descriptor.
 */
export type ComputePipelineDescriptor = {
  /** Compute stage; its `module` is compiled when the pipeline is created. */
  compute: ComputeStage
} & Omit<GPUComputePipelineDescriptor, 'compute' | 'layout'>

/**
 * Wraps a WebGPU compute pipeline together with the bind groups it uses.
 *
 * Construction is device-independent. Call `initialize(device)` to create the
 * shader module, the pipeline layout (built from each bind group's layout, in
 * array order) and the `GPUComputePipeline` itself.
 */
export class ComputePipeline {
  private readonly _descriptor: ComputePipelineDescriptor
  private readonly _bindGroups?: Array<IBindGroup>
  private _gpuPipeline?: GPUComputePipeline
  private _device?: GPUDevice
  // Compiled once per initialize(); reused when only the entry point changes.
  private _shaderModule?: GPUShaderModule

  /**
   * @param descriptor Pipeline description. It is shallow-copied (including
   *   `compute`), so later entry-point changes never mutate the caller's object.
   * @param bindGroups Bind groups used by the pipeline; each is passed its
   *   array index when bound (presumably its `@group` slot — see `IBindGroup`).
   */
  constructor(
    descriptor: ComputePipelineDescriptor,
    bindGroups?: Array<IBindGroup>
  ) {
    this._descriptor = { ...descriptor, compute: { ...descriptor.compute } }
    this._bindGroups = bindGroups
  }

  /**
   * The underlying `GPUComputePipeline`.
   * @throws Error if `initialize(device)` has not been called yet.
   */
  get value(): GPUComputePipeline {
    if (this._gpuPipeline === undefined) {
      throw new Error(
        'ComputePipeline: call initialize(device) before using the pipeline'
      )
    }
    return this._gpuPipeline
  }

  /**
   * Initializes all bind groups on `device`, compiles the shader module and
   * (re)creates the GPU pipeline. May be called again, e.g. with a new device.
   */
  initialize(device: GPUDevice): void {
    this._device = device
    this._bindGroups?.forEach((g) => g.initialize(device))
    this._shaderModule = device.createShaderModule(
      this._descriptor.compute.module
    )
    this.createGPUPipeline(device, this._shaderModule)
  }

  /**
   * Changes the compute entry point. If the pipeline is already initialized it
   * is rebuilt immediately, reusing the compiled shader module.
   */
  setEntryPoint(value: string): void {
    this._descriptor.compute.entryPoint = value
    if (this._device !== undefined && this._shaderModule !== undefined) {
      this.createGPUPipeline(this._device, this._shaderModule)
    }
  }

  /**
   * @deprecated Misnamed (this is a compute pipeline, not a fragment stage);
   *   use {@link ComputePipeline.setEntryPoint} instead.
   */
  setFragmentEntrypoint(value: string): void {
    this.setEntryPoint(value)
  }

  /**
   * Sets this pipeline on `passEncoder` and binds every bind group, passing
   * each one its index in the bind-group array.
   * @throws Error if `initialize(device)` has not been called yet.
   */
  bindTo(passEncoder: GPUComputePassEncoder): void {
    passEncoder.setPipeline(this.value)
    this._bindGroups?.forEach((bg, index) => bg.bindTo(passEncoder, index))
  }

  /** Builds the pipeline layout and the compute pipeline from the descriptor. */
  private createGPUPipeline(device: GPUDevice, module: GPUShaderModule): void {
    // `bindGroupLayouts` is required by createPipelineLayout, so pass an empty
    // list (not undefined) when the pipeline has no bind groups.
    const layout = device.createPipelineLayout({
      bindGroupLayouts: (this._bindGroups ?? []).map((g) =>
        g.getLayout(device)
      ),
    })
    this._gpuPipeline = device.createComputePipeline({
      ...this._descriptor,
      layout,
      compute: { ...this._descriptor.compute, module },
    })
  }
}
