import { IBindGroup } from './IBindGroup'

/**
 * Replaces a programmable stage's compiled `module` with the descriptor used
 * to build it, so shader modules can be (re)created for whichever device the
 * pipeline is initialized on.
 */
type WithShaderModuleDescriptor<TStage> = Omit<TStage, 'module'> & {
  module: GPUShaderModuleDescriptor
}

/** Vertex stage whose shader is supplied as a `GPUShaderModuleDescriptor`. */
export type VertexState = WithShaderModuleDescriptor<GPUVertexState>

/** Fragment stage whose shader is supplied as a `GPUShaderModuleDescriptor`. */
export type FragmentState = WithShaderModuleDescriptor<GPUFragmentState>

/**
 * Render pipeline description accepted by `RenderPipeline`.
 *
 * `layout` is intentionally absent: `RenderPipeline` builds the pipeline
 * layout from its bind groups. Both shader stages are required and carry
 * module descriptors instead of compiled modules.
 */
export type PipelineDescriptor = {
  vertex: VertexState
  fragment: FragmentState
} & Omit<GPURenderPipelineDescriptor, 'vertex' | 'fragment' | 'layout'>

/**
 * Owns a WebGPU render pipeline built from a {@link PipelineDescriptor}.
 *
 * Shader modules and the pipeline layout are created from descriptors when
 * {@link RenderPipeline.initialize} is called with a device. The layout is
 * derived from the bind groups passed to the constructor.
 */
export class RenderPipeline {
  private _descriptor: PipelineDescriptor
  private readonly _bindGroups?: readonly IBindGroup[]
  private _gpuPipeline?: GPURenderPipeline
  private _device?: GPUDevice

  /**
   * @param descriptor - Pipeline description whose shader stages carry module descriptors.
   * @param bindGroups - Optional bind groups. The group at array index `i` supplies
   *   layout slot `i` and is bound to group index `i` in {@link RenderPipeline.bindTo}.
   */
  constructor(
    descriptor: PipelineDescriptor,
    bindGroups?: readonly IBindGroup[]
  ) {
    this._descriptor = descriptor
    this._bindGroups = bindGroups
  }

  /**
   * The compiled GPU pipeline.
   *
   * @throws Error if accessed before {@link RenderPipeline.initialize}.
   */
  get value(): GPURenderPipeline {
    if (!this._gpuPipeline) {
      throw new Error(
        'RenderPipeline has not been initialized; call initialize(device) first'
      )
    }
    return this._gpuPipeline
  }

  /**
   * Initializes every bind group on `device`, then builds the GPU pipeline.
   *
   * @param device - Device used for all GPU object creation, including later rebuilds.
   */
  initialize(device: GPUDevice): void {
    this._device = device
    this._bindGroups?.forEach((g) => g.initialize(device))
    this.createGPUPipeline(device)
  }

  /**
   * Changes the fragment shader entry point and rebuilds the pipeline if it
   * has already been initialized.
   *
   * @param value - Name of the fragment entry point function in the shader.
   */
  setFragmentEntrypoint(value: string): void {
    // Replace the stored descriptor rather than mutating it in place: the
    // original object belongs to the caller and may be shared.
    this._descriptor = {
      ...this._descriptor,
      fragment: { ...this._descriptor.fragment, entryPoint: value },
    }
    if (this._device) {
      this.createGPUPipeline(this._device)
    }
  }

  /**
   * Sets this pipeline on the pass and binds each bind group at its array index.
   *
   * @throws Error if called before {@link RenderPipeline.initialize}.
   */
  bindTo(passEncoder: GPURenderPassEncoder): void {
    passEncoder.setPipeline(this.value)
    this._bindGroups?.forEach((bg, index) => bg.bindTo(passEncoder, index))
  }

  /** Compiles both shader stages and the layout, then creates the render pipeline. */
  private createGPUPipeline(device: GPUDevice): void {
    const { vertex, fragment } = this._descriptor
    this._gpuPipeline = device.createRenderPipeline({
      ...this._descriptor,
      layout: device.createPipelineLayout({
        // `bindGroupLayouts` is a required member: pass an empty list when the
        // pipeline was constructed without bind groups instead of `undefined`.
        bindGroupLayouts: (this._bindGroups ?? []).map((g) => g.getLayout(device)),
      }),
      vertex: { ...vertex, module: device.createShaderModule(vertex.module) },
      fragment: { ...fragment, module: device.createShaderModule(fragment.module) },
    })
  }
}
