Maximise agent buffer size

This commit is contained in:
Andras Schmelczer 2023-05-27 10:49:09 +01:00
parent f1808d5707
commit 5dc943bb91
No known key found for this signature in database
GPG key ID: FC8F2C3D3D1A718C
2 changed files with 12 additions and 8 deletions

View file

@ -26,12 +26,8 @@ export class AgentGenerationPipeline {
public constructor(
private readonly device: GPUDevice,
private readonly commonState: CommonState,
private readonly agentCount: number
private readonly maxAgentCountUpperLimit: number
) {
if (agentCount <= 0 || agentCount != Math.floor(agentCount)) {
throw new Error('Agent count must be a positive integer');
}
this.bindGroupLayout = device.createBindGroupLayout({
entries: [
{
@ -59,7 +55,7 @@ export class AgentGenerationPipeline {
});
this.agentsBuffer = this.device.createBuffer({
size: agentCount * AGENT_SIZE_IN_BYTES,
size: this.maxAgentCount * AGENT_SIZE_IN_BYTES,
usage: GPUBufferUsage.STORAGE,
});
@ -135,6 +131,13 @@ export class AgentGenerationPipeline {
});
}
public get maxAgentCount(): number {
return Math.min(
this.maxAgentCountUpperLimit,
Math.floor(this.device.limits.maxBufferSize / AGENT_SIZE_IN_BYTES)
);
}
public spawnFirstGeneration(): void {
const commandEncoder = this.device.createCommandEncoder();
@ -143,7 +146,7 @@ export class AgentGenerationPipeline {
passEncoder.setPipeline(this.firstGenerationPipeline);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.dispatchWorkgroups(
Math.ceil(this.agentCount / AgentGenerationPipeline.WORKGROUP_SIZE)
Math.ceil(this.maxAgentCount / AgentGenerationPipeline.WORKGROUP_SIZE)
);
passEncoder.end();
@ -160,7 +163,7 @@ export class AgentGenerationPipeline {
this.commonState.execute(passEncoder);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.dispatchWorkgroups(
Math.ceil(this.agentCount / AgentGenerationPipeline.WORKGROUP_SIZE)
Math.ceil(this.maxAgentCount / AgentGenerationPipeline.WORKGROUP_SIZE)
);
passEncoder.end();

View file

@ -21,6 +21,7 @@ export const initializeGpu = async (): Promise<GPUDevice> => {
const gpuDevice = await adapter.requestDevice({
requiredLimits: {
maxBufferSize: adapter.limits.maxBufferSize,
maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize,
maxComputeInvocationsPerWorkgroup: adapter.limits.maxComputeInvocationsPerWorkgroup,
maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX,
maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY,