diff --git a/src/pipelines/agents/agent-generation/agent-generation-pipeline.ts b/src/pipelines/agents/agent-generation/agent-generation-pipeline.ts index ca2d0ad..ba30736 100644 --- a/src/pipelines/agents/agent-generation/agent-generation-pipeline.ts +++ b/src/pipelines/agents/agent-generation/agent-generation-pipeline.ts @@ -26,12 +26,8 @@ export class AgentGenerationPipeline { public constructor( private readonly device: GPUDevice, private readonly commonState: CommonState, - private readonly agentCount: number + private readonly maxAgentCountUpperLimit: number ) { - if (agentCount <= 0 || agentCount != Math.floor(agentCount)) { - throw new Error('Agent count must be a positive integer'); - } - this.bindGroupLayout = device.createBindGroupLayout({ entries: [ { @@ -59,7 +55,7 @@ export class AgentGenerationPipeline { }); this.agentsBuffer = this.device.createBuffer({ - size: agentCount * AGENT_SIZE_IN_BYTES, + size: this.maxAgentCount * AGENT_SIZE_IN_BYTES, usage: GPUBufferUsage.STORAGE, }); @@ -135,6 +131,13 @@ export class AgentGenerationPipeline { }); } + public get maxAgentCount(): number { + return Math.min( + this.maxAgentCountUpperLimit, + Math.floor(this.device.limits.maxBufferSize / AGENT_SIZE_IN_BYTES) + ); + } + public spawnFirstGeneration(): void { const commandEncoder = this.device.createCommandEncoder(); @@ -143,7 +146,7 @@ export class AgentGenerationPipeline { passEncoder.setPipeline(this.firstGenerationPipeline); passEncoder.setBindGroup(1, this.bindGroup); passEncoder.dispatchWorkgroups( - Math.ceil(this.agentCount / AgentGenerationPipeline.WORKGROUP_SIZE) + Math.ceil(this.maxAgentCount / AgentGenerationPipeline.WORKGROUP_SIZE) ); passEncoder.end(); @@ -160,7 +163,7 @@ export class AgentGenerationPipeline { this.commonState.execute(passEncoder); passEncoder.setBindGroup(1, this.bindGroup); passEncoder.dispatchWorkgroups( - Math.ceil(this.agentCount / AgentGenerationPipeline.WORKGROUP_SIZE) + Math.ceil(this.maxAgentCount / AgentGenerationPipeline.WORKGROUP_SIZE) ); passEncoder.end(); diff --git a/src/utils/graphics/initialize-gpu.ts b/src/utils/graphics/initialize-gpu.ts index 4ca3ec0..47e9e4c 100644 --- a/src/utils/graphics/initialize-gpu.ts +++ b/src/utils/graphics/initialize-gpu.ts @@ -21,6 +21,7 @@ export const initializeGpu = async (): Promise => { const gpuDevice = await adapter.requestDevice({ requiredLimits: { maxBufferSize: adapter.limits.maxBufferSize, + maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize, maxComputeInvocationsPerWorkgroup: adapter.limits.maxComputeInvocationsPerWorkgroup, maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX, maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY,