This commit is contained in:
Andras Schmelczer 2026-05-13 21:07:10 +01:00
parent 34ac200437
commit 39b0160064
136 changed files with 7144 additions and 1965 deletions

View file

@ -0,0 +1,36 @@
struct Settings {
agentCount: u32,
padding0: u32,
padding1: u32,
padding2: u32,
};
struct Counters {
aliveAgentCount: atomic<u32>,
padding0: atomic<u32>,
padding1: atomic<u32>,
};
@group(1) @binding(0) var<uniform> settings: Settings;
@group(1) @binding(2) var<storage, read_write> counters: Counters;
@group(1) @binding(3) var<storage, read_write> compactedAgents: array<Agent>;
@compute @workgroup_size(64)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) workgroup_count: vec3<u32>
) {
let id = get_id(global_id, workgroup_count);
if id >= settings.agentCount {
return;
}
let agent = agents[id];
if agent.colorIndex < 0.0 {
return;
}
let compactedIndex = atomicAdd(&counters.aliveAgentCount, 1);
compactedAgents[compactedIndex] = agent;
}

View file

@ -5,8 +5,8 @@ struct Settings {
@group(1) @binding(0) var<uniform> settings: Settings;
struct Counters {
evenGenerationAlive: atomic<u32>,
oddGenerationAlive: atomic<u32>,
redAgentsAlive: atomic<u32>,
greenAgentsAlive: atomic<u32>,
};
@group(1) @binding(2) var<storage, read_write> counters: Counters;
@ -23,9 +23,13 @@ fn main(
return;
}
if agents[id].generation % 2 == 0 {
atomicAdd(&counters.evenGenerationAlive, 1);
if agents[id].colorIndex < 0.0 {
return;
}
if agents[id].colorIndex < 0.5 {
atomicAdd(&counters.redAgentsAlive, 1);
} else {
atomicAdd(&counters.oddGenerationAlive, 1);
atomicAdd(&counters.greenAgentsAlive, 1);
}
}

View file

@ -30,5 +30,8 @@ fn main(
randomPosition.xz * state.size,
random.r * 3.14 * 2,
0,
vec2<f32>(-1.0, -1.0),
0.0,
0.0,
);
}

View file

@ -1,27 +1,42 @@
import { vec2 } from 'gl-matrix';
import { getWorkgroupCounts } from '../../../utils/graphics/get-workgroup-counts';
import { smartCompile } from '../../../utils/graphics/smart-compile';
import { CommonState } from '../../common-state/common-state';
import { AGENT_SIZE_IN_BYTES } from './agent';
import compactionShader from './agent-compaction.wgsl?raw';
import countingShader from './agent-counting.wgsl?raw';
import firstGenerationShader from './agent-first-generation.wgsl?raw';
import resizeShader from './agent-resize.wgsl?raw';
import agentSchema from './agent-schema.wgsl?raw';
import { GenerationCounts } from './generation-counts';
export class AgentGenerationPipeline {
private static readonly WORKGROUP_SIZE = 64;
private static readonly UNIFORM_COUNT = 1;
private static readonly UNIFORM_COUNT = 4;
private static readonly COUNTER_COUNT = 3;
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly compactionBindGroupLayout: GPUBindGroupLayout;
private readonly uniforms: GPUBuffer;
private readonly bindGroup: GPUBindGroup;
private readonly compactionBindGroup: GPUBindGroup;
private readonly firstGenerationPipeline: GPUComputePipeline;
private readonly countingPipeline: GPUComputePipeline;
private readonly resizePipeline: GPUComputePipeline;
private readonly compactionPipeline: GPUComputePipeline;
public readonly agentsBuffer: GPUBuffer;
private readonly compactedAgentsBuffer: GPUBuffer;
public readonly countersBuffer: GPUBuffer;
public readonly countersStagingBuffer: GPUBuffer;
private readonly counterClearValues = new Uint32Array(
AgentGenerationPipeline.COUNTER_COUNT
);
private readonly agentCountUniformValues = new Uint32Array(
AgentGenerationPipeline.UNIFORM_COUNT
);
public constructor(
private readonly device: GPUDevice,
@ -54,9 +69,47 @@ export class AgentGenerationPipeline {
],
});
this.compactionBindGroupLayout = device.createBindGroupLayout({
entries: [
{
binding: 0,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'uniform',
},
},
{
binding: 1,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'storage',
},
},
{
binding: 2,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'storage',
},
},
{
binding: 3,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'storage',
},
},
],
});
this.agentsBuffer = this.device.createBuffer({
size: this.maxAgentCount * AGENT_SIZE_IN_BYTES,
usage: GPUBufferUsage.STORAGE,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
this.compactedAgentsBuffer = this.device.createBuffer({
size: this.maxAgentCount * AGENT_SIZE_IN_BYTES,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
this.countersBuffer = this.device.createBuffer({
@ -98,6 +151,36 @@ export class AgentGenerationPipeline {
],
});
this.compactionBindGroup = this.device.createBindGroup({
layout: this.compactionBindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: {
buffer: this.agentsBuffer,
},
},
{
binding: 2,
resource: {
buffer: this.countersBuffer,
},
},
{
binding: 3,
resource: {
buffer: this.compactedAgentsBuffer,
},
},
],
});
this.firstGenerationPipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [commonState.bindGroupLayout, this.bindGroupLayout],
@ -122,16 +205,79 @@ export class AgentGenerationPipeline {
entryPoint: 'main',
},
});
this.resizePipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [commonState.bindGroupLayout, this.bindGroupLayout],
}),
compute: {
module: smartCompile(device, CommonState.shaderCode, agentSchema, resizeShader),
entryPoint: 'main',
},
});
this.compactionPipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [commonState.bindGroupLayout, this.compactionBindGroupLayout],
}),
compute: {
module: smartCompile(
device,
CommonState.shaderCode,
agentSchema,
compactionShader
),
entryPoint: 'main',
},
});
}
public get maxAgentCount(): number {
return Math.min(
this.maxAgentCountUpperLimit,
Number.isFinite(this.maxAgentCountUpperLimit)
? this.maxAgentCountUpperLimit
: Number.POSITIVE_INFINITY,
Math.floor(this.device.limits.maxBufferSize / AGENT_SIZE_IN_BYTES) - 1,
this.device.limits.maxComputeWorkgroupsPerDimension ** 3
);
}
public writeAgents(agentOffset: number, data: Float32Array): void {
this.device.queue.writeBuffer(
this.agentsBuffer,
agentOffset * AGENT_SIZE_IN_BYTES,
data
);
}
public resizeAgents(agentCount: number, scale: vec2): void {
if (agentCount <= 0 || vec2.equals(scale, vec2.fromValues(1, 1))) {
return;
}
this.device.queue.writeBuffer(
this.uniforms,
0,
new Float32Array([scale[0], scale[1], agentCount, 0])
);
const commandEncoder = this.device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
this.commonState.execute(passEncoder);
passEncoder.setPipeline(this.resizePipeline);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.dispatchWorkgroups(
...getWorkgroupCounts(
this.device,
agentCount,
AgentGenerationPipeline.WORKGROUP_SIZE
)
);
passEncoder.end();
this.device.queue.submit([commandEncoder.finish()]);
}
public spawnFirstGeneration(): void {
const commandEncoder = this.device.createCommandEncoder();
@ -152,8 +298,11 @@ export class AgentGenerationPipeline {
}
public async countAgents(agentCount: number): Promise<GenerationCounts> {
this.device.queue.writeBuffer(this.countersBuffer, 0, new Uint32Array([0, 0]));
this.device.queue.writeBuffer(this.uniforms, 0, new Uint32Array([agentCount]));
this.counterClearValues.fill(0);
this.agentCountUniformValues.fill(0);
this.agentCountUniformValues[0] = agentCount;
this.device.queue.writeBuffer(this.countersBuffer, 0, this.counterClearValues);
this.device.queue.writeBuffer(this.uniforms, 0, this.agentCountUniformValues);
const commandEncoder = this.device.createCommandEncoder();
@ -190,10 +339,62 @@ export class AgentGenerationPipeline {
};
}
public async compactAgents(agentCount: number): Promise<number> {
if (agentCount <= 0) {
return 0;
}
this.counterClearValues.fill(0);
this.agentCountUniformValues.fill(0);
this.agentCountUniformValues[0] = agentCount;
this.device.queue.writeBuffer(this.countersBuffer, 0, this.counterClearValues);
this.device.queue.writeBuffer(this.uniforms, 0, this.agentCountUniformValues);
const commandEncoder = this.device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(this.compactionPipeline);
this.commonState.execute(passEncoder);
passEncoder.setBindGroup(1, this.compactionBindGroup);
passEncoder.dispatchWorkgroups(
...getWorkgroupCounts(
this.device,
agentCount,
AgentGenerationPipeline.WORKGROUP_SIZE
)
);
passEncoder.end();
commandEncoder.copyBufferToBuffer(
this.compactedAgentsBuffer,
0,
this.agentsBuffer,
0,
agentCount * AGENT_SIZE_IN_BYTES
);
commandEncoder.copyBufferToBuffer(
this.countersBuffer,
0,
this.countersStagingBuffer,
0,
Uint32Array.BYTES_PER_ELEMENT
);
this.device.queue.submit([commandEncoder.finish()]);
await this.countersStagingBuffer.mapAsync(GPUMapMode.READ);
const compactedCount = new Uint32Array(
this.countersStagingBuffer.getMappedRange().slice(0, Uint32Array.BYTES_PER_ELEMENT)
)[0];
this.countersStagingBuffer.unmap();
return compactedCount;
}
public destroy() {
this.uniforms.destroy();
this.countersBuffer.destroy();
this.countersStagingBuffer.destroy();
this.compactedAgentsBuffer.destroy();
this.agentsBuffer.destroy();
}
}

View file

@ -0,0 +1,74 @@
import { describe, expect, it } from 'vitest';
import { AGENT_FLOAT_COUNT, AGENT_SIZE_IN_BYTES } from './agent';
import compactionShader from './agent-compaction.wgsl?raw';
import countingShader from './agent-counting.wgsl?raw';
import firstGenerationShader from './agent-first-generation.wgsl?raw';
import resizeShader from './agent-resize.wgsl?raw';
import agentSchema from './agent-schema.wgsl?raw';
const wgslFloatCountByType: Record<string, number> = {
f32: 1,
'vec2<f32>': 2,
};
const getAgentStructFields = () => {
const match = /struct Agent\s*\{(?<body>[\s\S]*?)\n\}/.exec(agentSchema);
if (!match?.groups?.body) {
throw new Error('Agent struct was not found in agent-schema.wgsl');
}
return match.groups.body
.split('\n')
.map((line) => line.trim().replace(/,$/, ''))
.filter(Boolean)
.map((line) => {
const fieldMatch = /^(?<name>\w+):\s*(?<type>[^,]+)$/.exec(line);
if (!fieldMatch?.groups) {
throw new Error(`Unsupported Agent field syntax: ${line}`);
}
return {
name: fieldMatch.groups.name,
type: fieldMatch.groups.type,
};
});
};
describe('Agent TS/WGSL contract', () => {
it('keeps the TypeScript float count aligned with the WGSL Agent struct', () => {
const fields = getAgentStructFields();
const wgslFloatCount = fields.reduce((sum, field) => {
const count = wgslFloatCountByType[field.type];
if (!count) {
throw new Error(`Unsupported WGSL Agent field type: ${field.type}`);
}
return sum + count;
}, 0);
expect(fields.map((field) => field.name)).toEqual([
'position',
'angle',
'colorIndex',
'targetPosition',
'targetAngle',
'introDelay',
]);
expect(wgslFloatCount).toBe(AGENT_FLOAT_COUNT);
expect(AGENT_SIZE_IN_BYTES).toBe(AGENT_FLOAT_COUNT * Float32Array.BYTES_PER_ELEMENT);
});
it('keeps generation shader workgroup sizes aligned with agent indexing', () => {
[firstGenerationShader, countingShader, resizeShader, compactionShader].forEach(
(shader) => {
expect(shader).toMatch(/@workgroup_size\(64\)/);
}
);
expect(agentSchema).toContain('workgroup_count.x * 64');
expect(agentSchema).toContain('workgroup_count.x * workgroup_count.y * 64');
expect(compactionShader).toContain('let id = get_id(global_id, workgroup_count);');
expect(compactionShader).toContain('if id >= settings.agentCount');
});
});

View file

@ -3,7 +3,11 @@ import { vec2 } from 'gl-matrix';
export interface Agent {
position: vec2;
angle: number;
generation: number;
colorIndex: number;
targetPosition: vec2;
targetAngle: number;
introDelay: number;
}
export const AGENT_SIZE_IN_BYTES = 4 * Float32Array.BYTES_PER_ELEMENT;
export const AGENT_FLOAT_COUNT = 8;
export const AGENT_SIZE_IN_BYTES = AGENT_FLOAT_COUNT * Float32Array.BYTES_PER_ELEMENT;

View file

@ -1,5 +1,7 @@
import { vec2 } from 'gl-matrix';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
} from '../../utils/graphics/cached-buffer-write';
import { getWorkgroupCounts } from '../../utils/graphics/get-workgroup-counts';
import { smartCompile } from '../../utils/graphics/smart-compile';
import { CommonState } from '../common-state/common-state';
@ -9,14 +11,19 @@ import shader from './agent.wgsl?raw';
export class AgentPipeline {
private static readonly WORKGROUP_SIZE = 64;
private static readonly UNIFORM_COUNT = 19;
private static readonly UNIFORM_COUNT = 8;
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly pipeline: GPUComputePipeline;
private readonly uniforms: GPUBuffer;
private bindGroup?: GPUBindGroup;
private previousTrailMapIn?: GPUTextureView;
private previousTrailMapOut?: GPUTextureView;
private readonly uniformValues = new Float32Array(AgentPipeline.UNIFORM_COUNT);
private readonly uniformCache = createCachedFloat32BufferWrite(
AgentPipeline.UNIFORM_COUNT
);
private readonly bindGroupsByTexture = new WeakMap<
GPUTextureView,
WeakMap<GPUTextureView, WeakMap<GPUTextureView, GPUBindGroup>>
>();
private agentCount = 0;
@ -45,115 +52,108 @@ export class AgentPipeline {
public setParameters({
deltaTime,
center,
radius,
brushTrailWeight,
moveSpeed,
turnSpeed,
sensorOffsetAngle,
sensorOffsetDistance,
nextGenerationSensorOffsetDistance,
currentGenerationAggression,
nextGenerationAggression,
nextGenerationSpeed,
isNextGenerationOdd,
turnWhenLost,
individualTrailWeight,
infectionProbability,
agentCount,
introProgress,
}: AgentSettings & {
deltaTime: number;
currentGenerationAggression: number;
nextGenerationAggression: number;
nextGenerationSensorOffsetDistance: number;
nextGenerationSpeed: number;
isNextGenerationOdd: number;
center: vec2;
radius: number;
infectionProbability: number;
agentCount: number;
introProgress?: number;
}) {
this.agentCount = agentCount;
this.device.queue.writeBuffer(
this.uniformValues[0] = moveSpeed * deltaTime;
this.uniformValues[1] = turnSpeed * deltaTime;
this.uniformValues[2] = (sensorOffsetAngle * Math.PI) / 180;
this.uniformValues[3] = sensorOffsetDistance;
this.uniformValues[4] = turnWhenLost;
this.uniformValues[5] = individualTrailWeight;
this.uniformValues[6] = agentCount;
this.uniformValues[7] = introProgress ?? 1;
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
0,
new Float32Array([
...center,
radius,
brushTrailWeight,
moveSpeed * deltaTime,
turnSpeed * deltaTime,
(sensorOffsetAngle * Math.PI) / 180,
sensorOffsetDistance,
currentGenerationAggression,
nextGenerationAggression,
nextGenerationSensorOffsetDistance,
nextGenerationSpeed * deltaTime,
isNextGenerationOdd,
turnWhenLost,
individualTrailWeight,
infectionProbability,
agentCount,
])
this.uniformValues,
this.uniformCache
);
}
public execute(
commandEncoder: GPUCommandEncoder,
trailMapIn: GPUTextureView,
trailMapOut: GPUTextureView
trailMapOut: GPUTextureView,
sourceMap: GPUTextureView
) {
this.ensureBindGroupExists(trailMapIn, trailMapOut);
const bindGroup = this.getBindGroup(trailMapIn, trailMapOut, sourceMap);
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(this.pipeline);
this.commonState.execute(passEncoder);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.setBindGroup(1, bindGroup);
passEncoder.dispatchWorkgroups(
...getWorkgroupCounts(this.device, this.agentCount, AgentPipeline.WORKGROUP_SIZE)
);
passEncoder.end();
}
private ensureBindGroupExists(trailMapIn: GPUTextureView, trailMapOut: GPUTextureView) {
if (
this.previousTrailMapIn !== trailMapIn ||
this.previousTrailMapOut !== trailMapOut
) {
this.bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: {
buffer: this.agentsBuffer,
},
},
{
binding: 2,
resource: trailMapIn,
},
{
binding: 3,
resource: trailMapOut,
},
],
});
this.previousTrailMapIn = trailMapIn;
this.previousTrailMapOut = trailMapOut;
private getBindGroup(
trailMapIn: GPUTextureView,
trailMapOut: GPUTextureView,
sourceMap: GPUTextureView
): GPUBindGroup {
let outputCache = this.bindGroupsByTexture.get(trailMapIn);
if (!outputCache) {
outputCache = new WeakMap<GPUTextureView, WeakMap<GPUTextureView, GPUBindGroup>>();
this.bindGroupsByTexture.set(trailMapIn, outputCache);
}
let sourceCache = outputCache.get(trailMapOut);
if (!sourceCache) {
sourceCache = new WeakMap<GPUTextureView, GPUBindGroup>();
outputCache.set(trailMapOut, sourceCache);
}
const cached = sourceCache.get(sourceMap);
if (cached) {
return cached;
}
const bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: {
buffer: this.agentsBuffer,
},
},
{
binding: 2,
resource: trailMapIn,
},
{
binding: 3,
resource: trailMapOut,
},
{
binding: 4,
resource: sourceMap,
},
],
});
sourceCache.set(sourceMap, bindGroup);
return bindGroup;
}
public destroy() {
@ -191,6 +191,13 @@ export class AgentPipeline {
format: 'rgba16float',
},
},
{
binding: 4,
visibility: GPUShaderStage.COMPUTE,
texture: {
sampleType: 'float',
},
},
],
};
}

View file

@ -1,11 +1,8 @@
export interface AgentSettings {
brushTrailWeight: number;
moveSpeed: number;
turnSpeed: number;
sensorOffsetAngle: number;
sensorOffsetDistance: number;
turnWhenLost: number;
individualTrailWeight: number;
currentGenerationAggression: number;
nextGenerationAggression: number;
}

View file

@ -1,37 +1,18 @@
struct Settings {
center: vec2<f32>,
radius: f32,
brushTrailWeight: f32,
currentGenerationMoveRate: f32,
moveRate: f32,
turnRate: f32,
sensorAngle: f32,
sensorOffset: f32,
currentGenerationAggression: f32,
nextGenerationAggression: f32,
nextGenerationSensorOffsetDistance: f32,
nextGenerationMoveRate: f32,
isNextGenerationOdd: f32,
turnWhenLost: f32,
individualTrailWeight: f32,
infectionProbability: f32,
agentCount: f32 // might be smaller than the length of the agents array
agentCount: f32,
introProgress: f32,
};
@group(1) @binding(0) var<uniform> settings: Settings;
// even generation's trail -> red channel
// odd generation's trail -> green channel
// unused -> blue channel
// brush -> alpha channel
@group(1) @binding(2) var trailMapIn: texture_2d<f32>;
@group(1) @binding(3) var trailMapOut: texture_storage_2d<rgba16float, write>;
@group(1) @binding(4) var sourceMap: texture_2d<f32>;
@compute @workgroup_size(64)
fn main(
@ -45,90 +26,125 @@ fn main(
}
var agent = agents[id];
if agent.colorIndex < 0.0 {
return;
}
let hasIntroTarget =
settings.introProgress < 0.999 &&
agent.targetPosition.x >= 0.0 &&
agent.targetPosition.y >= 0.0;
if hasIntroTarget && settings.introProgress < agent.introDelay {
return;
}
let random = textureSampleLevel(
noise,
noiseSampler,
vec2(
f32(id) % 23647 / 2000,
state.time % 3243 / 2000
),
fract(vec2(f32(id) * 0.7548777, state.time * 0.00017 + f32(id) * 0.5698403)),
0
);
let isFromCurrentGeneration = abs(agent.generation - settings.isNextGenerationOdd);
let isFromNextGeneration = 1.0 - isFromCurrentGeneration;
let isFromOddGeneration = agent.generation % 2;
let forwardSensor = sensor_position(agent.position, agent.angle, settings.sensorOffset, 0);
let leftSensor = sensor_position(agent.position, agent.angle, settings.sensorOffset, settings.sensorAngle);
let rightSensor = sensor_position(agent.position, agent.angle, settings.sensorOffset, -settings.sensorAngle);
let sensorOffset = mix(settings.sensorOffset, settings.nextGenerationSensorOffsetDistance, isFromNextGeneration);
let moveRate = mix(settings.currentGenerationMoveRate, settings.nextGenerationMoveRate, isFromNextGeneration);
let brushWeight = mix(settings.brushTrailWeight, 0, isFromNextGeneration);
let trailForward = sense(agent.position, agent.angle, sensorOffset, 0);
let trailLeft = sense(agent.position, agent.angle, sensorOffset, settings.sensorAngle);
let trailRight = sense(agent.position, agent.angle, sensorOffset, -settings.sensorAngle);
let trailForward = textureLoad(trailMapIn, forwardSensor, 0);
let trailLeft = textureLoad(trailMapIn, leftSensor, 0);
let trailRight = textureLoad(trailMapIn, rightSensor, 0);
let sourceForwardSample = textureLoad(sourceMap, forwardSensor, 0);
let sourceLeftSample = textureLoad(sourceMap, leftSensor, 0);
let sourceRightSample = textureLoad(sourceMap, rightSensor, 0);
var weightForward = brushWeight * trailForward.a;
var weightLeft = brushWeight * trailLeft.a;
var weightRight = brushWeight * trailRight.a;
let channelMask = get_channel_mask(agent.colorIndex);
let friendForward = dot(trailForward.rgb, channelMask);
let friendLeft = dot(trailLeft.rgb, channelMask);
let friendRight = dot(trailRight.rgb, channelMask);
let agression = mix(settings.currentGenerationAggression, settings.nextGenerationAggression, isFromNextGeneration) + weightForward;
let sourceForward = dot(sourceForwardSample.rgb, channelMask);
let sourceLeft = dot(sourceLeftSample.rgb, channelMask);
let sourceRight = dot(sourceRightSample.rgb, channelMask);
weightForward += mix(trailForward.r + agression * trailForward.g, trailForward.g + agression * trailForward.r, isFromOddGeneration);
weightLeft += mix(trailLeft.r + agression * trailLeft.g, trailLeft.g + agression * trailLeft.r, isFromOddGeneration);
weightRight += mix(trailRight.r + agression * trailRight.g, trailRight.g + agression * trailRight.r, isFromOddGeneration);
var rotation: f32;
let weightForward = friendForward + sourceForward * 24.0;
let weightLeft = friendLeft + sourceLeft * 24.0;
let weightRight = friendRight + sourceRight * 24.0;
var rotation = (random.r - 0.5) * settings.turnWhenLost;
if weightForward >= weightLeft && weightForward >= weightRight {
rotation = 0;
rotation = rotation * 0.25;
} else {
rotation = sign(weightLeft - weightRight) * settings.turnRate;
rotation += sign(weightLeft - weightRight) * settings.turnRate;
}
let nextPosition = clamp(
agent.position + vec2(cos(agent.angle), sin(agent.angle)) * moveRate,
vec2<f32>(0, 0),
state.size
);
if nextPosition.x == 0 || nextPosition.x == state.size.x || nextPosition.y == 0 || nextPosition.y == state.size.y {
rotation = 3.14159265359 + random.a - 0.5;
let sourceAtAgent = textureLoad(sourceMap, vec2<i32>(agent.position), 0);
let sourceAtAgentStrength = clamp(dot(sourceAtAgent.rgb, channelMask), 0.0, 1.0);
var moveRate = settings.moveRate * mix(1.0, 0.08, sourceAtAgentStrength);
var introTargetOffset = vec2<f32>(0.0, 0.0);
var introTargetDistance = 0.0;
if hasIntroTarget {
introTargetOffset = agent.targetPosition - agent.position;
introTargetDistance = length(introTargetOffset);
let targetAngle = atan2(introTargetOffset.y, introTargetOffset.x);
let nearTitle = 1.0 - smoothstep(4.0, max(28.0, settings.sensorOffset * 0.75), introTargetDistance);
let desiredAngle = mix(targetAngle, agent.targetAngle, nearTitle * 0.2);
let introTurn = angle_delta(agent.angle, desiredAngle);
rotation = clamp(introTurn, -settings.turnRate * 3.4, settings.turnRate * 3.4)
+ (random.g - 0.5) * settings.turnWhenLost * 0.18;
moveRate = min(settings.moveRate * mix(2.65, 0.01, nearTitle), introTargetDistance);
}
var trail = vec4<f32>(settings.individualTrailWeight, 0, 0, 0);
if isFromOddGeneration == 1.0 {
trail = vec4<f32>(0, settings.individualTrailWeight, 0, 0);
}
var trailBelow = textureLoad(trailMapIn, vec2<i32>(nextPosition), 0);
agent.angle += rotation;
trailBelow += trail;
if settings.radius > 0 && length(settings.center - agent.position) < settings.radius {
agent.generation = settings.isNextGenerationOdd;
// clear trail map below so the agent won't die immediately
// trailBelow.r = (1 - settings.isNextGenerationOdd) * (trailBelow.r + trailBelow.g);
// trailBelow.g = settings.isNextGenerationOdd * (trailBelow.r + trailBelow.g);
} else {
let relativeWeight = mix(trailBelow.g - trailBelow.r, trailBelow.r - trailBelow.g, isFromOddGeneration);
if (relativeWeight > 0 && (
(isFromCurrentGeneration == 1.0 && trailBelow.a == 0 && random.b < settings.infectionProbability)
|| (isFromCurrentGeneration == 0.0 && trailBelow.a > 0)
)) || (trailBelow.a > 0 && isFromCurrentGeneration == 0.0){
// trailBelow.r = isFromOddGeneration * (trailBelow.r + trailBelow.g);
// trailBelow.g = (1 - isFromOddGeneration) * (trailBelow.r + trailBelow.g);
agent.generation = (agent.generation + 1) % 2;
var step = vec2(cos(agent.angle), sin(agent.angle)) * moveRate;
if hasIntroTarget {
step = vec2<f32>(0.0, 0.0);
if introTargetDistance > 0.5 {
step = introTargetOffset / introTargetDistance * moveRate;
}
}
textureStore(trailMapOut, vec2<i32>(nextPosition), trailBelow);
let maxPosition = state.size - vec2<f32>(1.0, 1.0);
let nextPosition = clamp(agent.position + step, vec2<f32>(0, 0), maxPosition);
if nextPosition.x == 0 || nextPosition.x == maxPosition.x || nextPosition.y == 0 || nextPosition.y == maxPosition.y {
rotation = 3.14159265359 + random.a - 0.5;
}
let sourceBelow = textureLoad(sourceMap, vec2<i32>(nextPosition), 0);
let sourceBelowStrength = dot(sourceBelow.rgb, channelMask);
let trailWeight = settings.individualTrailWeight * (1.0 + sourceBelowStrength * 16.0);
var trailBelow = textureLoad(trailMapIn, vec2<i32>(nextPosition), 0);
trailBelow = vec4<f32>(
trailBelow.rgb + channelMask * trailWeight,
max(trailBelow.a, 0.0)
);
agent.angle += rotation;
agent.position = nextPosition;
textureStore(trailMapOut, vec2<i32>(nextPosition), trailBelow);
agents[id] = agent;
}
fn sense(agentPosition: vec2<f32>, agentAngle: f32, sensorOffset: f32, sensorOffsetAngle: f32) -> vec4<f32> {
fn sensor_position(agentPosition: vec2<f32>, agentAngle: f32, sensorOffset: f32, sensorOffsetAngle: f32) -> vec2<i32> {
let sensorAngle = agentAngle + sensorOffsetAngle;
let sensorPosition = vec2<i32>(agentPosition + vec2(cos(sensorAngle), sin(sensorAngle)) * sensorOffset);
return textureLoad(trailMapIn, sensorPosition, 0);
return vec2<i32>(clamp(
agentPosition + vec2(cos(sensorAngle), sin(sensorAngle)) * sensorOffset,
vec2<f32>(0, 0),
state.size - vec2<f32>(1, 1)
));
}
fn get_channel_mask(colorIndex: f32) -> vec3<f32> {
if colorIndex < 0.5 {
return vec3<f32>(1, 0, 0);
}
if colorIndex < 1.5 {
return vec3<f32>(0, 1, 0);
}
return vec3<f32>(0, 0, 1);
}
fn angle_delta(sourceAngle: f32, targetAngle: f32) -> f32 {
return atan2(sin(targetAngle - sourceAngle), cos(targetAngle - sourceAngle));
}

View file

@ -1,14 +1,24 @@
import { vec2 } from 'gl-matrix';
import { appConfig } from '../../config';
import { clamp } from '../../utils/clamp';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
} from '../../utils/graphics/cached-buffer-write';
import { smartCompile } from '../../utils/graphics/smart-compile';
import { CommonState } from '../common-state/common-state';
import { BrushSettings } from './brush-settings';
import shader from './brush.wgsl?raw';
interface LineSegment {
from: vec2;
to: vec2;
}
export class BrushPipeline {
private static readonly UNIFORM_COUNT = 2;
private static readonly MAX_LINE_COUNT = 20;
private static readonly UNIFORM_COUNT = 8;
private static readonly MAX_LINE_COUNT = appConfig.pipelines.brush.maxLineCount;
private static readonly VERTICES_PER_LINE_SEGMENT = 6;
private static readonly ATTRIBUTES_PER_LINE_SEGMENT = 6;
@ -16,10 +26,20 @@ export class BrushPipeline {
private readonly bindGroup: GPUBindGroup;
private readonly pipeline: GPURenderPipeline;
private readonly uniforms: GPUBuffer;
private readonly uniformValues = new Float32Array(BrushPipeline.UNIFORM_COUNT);
private readonly uniformCache = createCachedFloat32BufferWrite(
BrushPipeline.UNIFORM_COUNT
);
private readonly vertexBuffer: GPUBuffer;
private readonly vertexUploadData = new Float32Array(
BrushPipeline.MAX_LINE_COUNT *
BrushPipeline.VERTICES_PER_LINE_SEGMENT *
BrushPipeline.ATTRIBUTES_PER_LINE_SEGMENT
);
private linePoints: Array<vec2> = [];
private actualPoints: Array<vec2> = [];
private lineSegments: Array<LineSegment> = [];
private actualSegments: Array<LineSegment> = [];
public constructor(
private readonly device: GPUDevice,
@ -72,18 +92,6 @@ export class BrushPipeline {
targets: [
{
format: 'rgba16float',
blend: {
color: {
operation: 'add',
srcFactor: 'zero',
dstFactor: 'one',
},
alpha: {
operation: 'max',
srcFactor: 'one',
dstFactor: 'one',
},
},
},
],
},
@ -111,112 +119,188 @@ export class BrushPipeline {
}
public addSwipe(position: vec2) {
this.linePoints.push(position);
const previousPosition = this.linePoints[this.linePoints.length - 1] ?? position;
this.addSwipeSegment(previousPosition, position);
this.linePoints.push(vec2.clone(position));
}
public addSwipeSegment(from: vec2, to: vec2) {
this.lineSegments.push({
from: vec2.clone(from),
to: vec2.clone(to),
});
}
public clearSwipes() {
this.linePoints.length = 0;
this.lineSegments.length = 0;
this.actualSegments.length = 0;
}
public setParameters({ brushSize, brushSizeVariation }: BrushSettings) {
this.device.queue.writeBuffer(
public setParameters({
brushSize,
brushSizeVariation,
selectedColorIndex,
isErasing,
}: BrushSettings & { selectedColorIndex: number; isErasing: boolean }) {
this.uniformValues[0] = brushSize / 2;
this.uniformValues[1] = Math.floor((brushSize / 2) * brushSizeVariation);
this.uniformValues[2] = 0;
this.uniformValues[3] = 0;
this.uniformValues[4] = !isErasing && selectedColorIndex === 0 ? 1 : 0;
this.uniformValues[5] = !isErasing && selectedColorIndex === 1 ? 1 : 0;
this.uniformValues[6] = !isErasing && selectedColorIndex === 2 ? 1 : 0;
this.uniformValues[7] = isErasing ? 0 : 1;
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
0,
new Float32Array([brushSize / 2, Math.floor((brushSize / 2) * brushSizeVariation)])
this.uniformValues,
this.uniformCache
);
this.actualPoints = this.linePoints.slice();
this.linePoints.splice(0, this.linePoints.length - 1);
this.actualSegments = this.lineSegments.slice();
this.lineSegments.length = 0;
if (this.actualPoints.length === 0) {
if (this.actualSegments.length === 0) {
return;
}
if (this.actualPoints.length === 1) {
this.actualPoints.push(this.actualPoints[0]); // allow single point swipes
if (this.actualSegments.length > BrushPipeline.MAX_LINE_COUNT) {
this.actualSegments = BrushPipeline.subsampleSegments(this.actualSegments);
}
if (this.actualPoints.length > BrushPipeline.MAX_LINE_COUNT + 1) {
this.actualPoints = BrushPipeline.subsampleLinePoints(this.actualPoints);
const lineCount = this.lineCount;
let floatOffset = 0;
for (let i = 0; i < lineCount; i++) {
const segment = this.actualSegments[i];
floatOffset = this.writeSegmentVertices(
this.vertexUploadData,
floatOffset,
segment.from,
segment.to,
brushSize / 2
);
}
this.device.queue.writeBuffer(
this.vertexBuffer,
0,
new Float32Array(
new Array(this.lineCount).fill(0).flatMap((_, i) => {
const from = this.actualPoints[i];
const to = this.actualPoints[i + 1];
const [a, b, c, d] = this.getSegmentBoundingBox(from, to, brushSize / 2);
return [a, b, c, b, c, d].flatMap((v) => [...v, ...from, ...to]);
})
)
this.vertexUploadData,
0,
floatOffset
);
}
private static subsampleLinePoints(points: Array<vec2>): Array<vec2> {
const lines = [];
for (let i = 0; i < points.length - 2; i++) {
lines.push({
from: points[i],
to: points[i + 1],
length: vec2.dist(points[i], points[i + 1]),
});
private static subsampleSegments(segments: Array<LineSegment>): Array<LineSegment> {
if (segments.length <= BrushPipeline.MAX_LINE_COUNT) {
return segments;
}
const sumLength = lines.reduce((sum, line) => sum + line.length, 0);
let currentLineIndex = 0;
let lineLengthSoFar = 0;
const result: Array<vec2> = [points[0]];
for (let i = 1; i < BrushPipeline.MAX_LINE_COUNT; i++) {
const t = (i * sumLength) / (BrushPipeline.MAX_LINE_COUNT + 1);
while (lineLengthSoFar + lines[currentLineIndex].length < t) {
lineLengthSoFar += lines[currentLineIndex].length;
currentLineIndex++;
}
const line = lines[currentLineIndex];
const position = vec2.lerp(
vec2.create(),
line.from,
line.to,
(t - lineLengthSoFar) / line.length
const result: Array<LineSegment> = [];
for (let i = 0; i < BrushPipeline.MAX_LINE_COUNT; i++) {
const index = Math.round(
(i * (segments.length - 1)) / (BrushPipeline.MAX_LINE_COUNT - 1)
);
result.push(position);
result.push(segments[index]);
}
result.push(points[points.length - 1]);
return result;
}
private getSegmentBoundingBox(from: vec2, to: vec2, width: number): Array<vec2> {
let dir = vec2.sub(vec2.create(), to, from);
vec2.normalize(dir, dir);
private writeSegmentVertices(
target: Float32Array,
offset: number,
from: vec2,
to: vec2,
width: number
): number {
const dx = to[0] - from[0];
const dy = to[1] - from[1];
const length = Math.hypot(dx, dy);
const directionX = length > 0 ? dx / length : 1;
const directionY = length > 0 ? dy / length : 0;
const scaledDirectionX = directionX * width;
const scaledDirectionY = directionY * width;
const perpendicularX = directionY * width;
const perpendicularY = -directionX * width;
if (vec2.len(dir) === 0) {
dir = vec2.fromValues(1, 0); // allow single point swipes
}
const startX = from[0] - scaledDirectionX;
const startY = from[1] - scaledDirectionY;
const endX = to[0] + scaledDirectionX;
const endY = to[1] + scaledDirectionY;
const perp = vec2.fromValues(dir[1], -dir[0]);
offset = this.writeVertex(
target,
offset,
startX + perpendicularX,
startY + perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
startX - perpendicularX,
startY - perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
endX + perpendicularX,
endY + perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
startX - perpendicularX,
startY - perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
endX + perpendicularX,
endY + perpendicularY,
from,
to
);
return this.writeVertex(
target,
offset,
endX - perpendicularX,
endY - perpendicularY,
from,
to
);
}
vec2.scale(dir, dir, width);
vec2.scale(perp, perp, width);
const offsetStart = vec2.sub(vec2.create(), from, dir);
const offsetEnd = vec2.add(vec2.create(), to, dir);
return [
vec2.add(vec2.create(), offsetStart, perp),
vec2.sub(vec2.create(), offsetStart, perp),
vec2.add(vec2.create(), offsetEnd, perp),
vec2.sub(vec2.create(), offsetEnd, perp),
];
private writeVertex(
target: Float32Array,
offset: number,
screenX: number,
screenY: number,
from: vec2,
to: vec2
): number {
target[offset++] = screenX;
target[offset++] = screenY;
target[offset++] = from[0];
target[offset++] = from[1];
target[offset++] = to[0];
target[offset++] = to[1];
return offset;
}
public execute(commandEncoder: GPUCommandEncoder, trailMapOut: GPUTextureView) {
if (this.lineCount === 0) {
return;
}
const renderPassDescriptor: GPURenderPassDescriptor = {
colorAttachments: [
{
@ -256,6 +340,6 @@ export class BrushPipeline {
}
private get lineCount() {
return clamp(this.actualPoints.length - 1, 0, BrushPipeline.MAX_LINE_COUNT);
return clamp(this.actualSegments.length, 0, BrushPipeline.MAX_LINE_COUNT);
}
}

View file

@ -1,4 +1,6 @@
export interface BrushSettings {
brushSize: number;
eraserSize: number;
mirrorSegmentCount: number;
brushSizeVariation: number;
}

View file

@ -1,6 +1,9 @@
struct Settings {
brushSize: f32,
brushSizeVariation: f32
brushSizeVariation: f32,
padding0: f32,
padding1: f32,
brushValue: vec4<f32>,
};
@group(1) @binding(0) var<uniform> settings: Settings;
@ -19,7 +22,7 @@ fn vertex(
@location(2) @interpolate(flat) end: vec2<f32>
) -> VertexOutput {
let uv = screenPosition / state.size;
let position = uv * 2.0 - 1.0;
let position = vec2(uv.x * 2.0 - 1.0, 1.0 - uv.y * 2.0);
return VertexOutput(vec4(position, 0.0, 1.0), screenPosition, start, end);
}
@ -29,20 +32,34 @@ fn fragment(
@location(1) start: vec2<f32>,
@location(2) end: vec2<f32>
) -> @location(0) vec4<f32> {
var distance = distanceFromLine(screenPosition, start, end);
let noise = textureSample(noise, noiseSampler, screenPosition / state.size / 50);
distance += noise.r * settings.brushSizeVariation;
let distance = distanceFromLine(screenPosition, start, end);
let coarseNoise = textureSample(noise, noiseSampler, fract(screenPosition / 160.0)).r;
let grainNoise = textureSample(
noise,
noiseSampler,
fract(screenPosition / 22.0 + vec2(0.31, 0.67))
).r;
let radius = settings.brushSize + (coarseNoise - 0.5) * settings.brushSizeVariation * 2.0;
let feather = max(1.0, settings.brushSize * 0.22);
let edge = 1.0 - smoothstep(radius - feather, radius + feather, distance);
let strength = edge * mix(0.45, 1.0, grainNoise);
if(distance > settings.brushSize) {
if(strength < 0.02) {
discard;
}
return vec4(0, 0, 0, 1);
return vec4(settings.brushValue.rgb * strength, settings.brushValue.a * strength);
}
fn distanceFromLine(position: vec2<f32>, start: vec2<f32>, end: vec2<f32>) -> f32 {
let pa = position - start;
let direction = end - start;
let q = clamp(dot(pa, direction) / dot(direction, direction), 0, 1);
let denominator = dot(direction, direction);
if denominator <= 0.0001 {
return length(pa);
}
let q = clamp(dot(pa, direction) / denominator, 0, 1);
return length(pa - direction * q);
}

View file

@ -8,7 +8,7 @@ import { generateNoise } from '../../utils/graphics/noise';
export class CommonState {
private static readonly UNIFORM_COUNT = 4;
private static readonly NOISE_TEXTURE_SIZE = 1024;
private static readonly NOISE_TEXTURE_SIZE = 2048;
private readonly uniforms: GPUBuffer;
private readonly uniformValues = new Float32Array(CommonState.UNIFORM_COUNT);

View file

@ -1,19 +1,28 @@
import { vec2 } from 'gl-matrix';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
} from '../../utils/graphics/cached-buffer-write';
import { smartCompile } from '../../utils/graphics/smart-compile';
import shader from './copy.wgsl?raw';
export class CopyPipeline {
private static readonly UNIFORM_COUNT = 2;
private static readonly DEFAULT_SCALE = vec2.fromValues(1, 1);
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly pipeline: GPURenderPipeline;
private readonly uniforms: GPUBuffer;
private readonly uniformValues = new Float32Array(CopyPipeline.UNIFORM_COUNT);
private readonly uniformCache = createCachedFloat32BufferWrite(
CopyPipeline.UNIFORM_COUNT
);
private readonly sampler: GPUSampler;
private readonly vertexBuffer: GPUBuffer;
private bindGroup?: GPUBindGroup;
private previousTrailMapIn?: GPUTextureView;
private readonly bindGroupsByInput = new WeakMap<GPUTextureView, GPUBindGroup>();
public constructor(private readonly device: GPUDevice) {
this.bindGroupLayout = device.createBindGroupLayout(CopyPipeline.bindGroupLayout);
@ -23,6 +32,11 @@ export class CopyPipeline {
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
this.sampler = this.device.createSampler({
magFilter: 'linear',
minFilter: 'linear',
});
this.vertexBuffer = device.createBuffer({
size: 2 * 4 * Float32Array.BYTES_PER_ELEMENT, // 4 x vec2<f32>
usage: GPUBufferUsage.VERTEX,
@ -79,9 +93,16 @@ export class CopyPipeline {
commandEncoder: GPUCommandEncoder,
trailMapIn: GPUTextureView,
trailMapOut: GPUTextureView,
scale: vec2 = vec2.fromValues(1, 1)
scale: vec2 = CopyPipeline.DEFAULT_SCALE
) {
this.device.queue.writeBuffer(this.uniforms, 0, new Float32Array(scale));
this.uniformValues[0] = scale[0];
this.uniformValues[1] = scale[1];
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
this.uniformValues,
this.uniformCache
);
const renderPassDescriptor: GPURenderPassDescriptor = {
colorAttachments: [
@ -93,10 +114,10 @@ export class CopyPipeline {
],
};
this.ensureBindGroupExists(trailMapIn);
const bindGroup = this.getBindGroup(trailMapIn);
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
passEncoder.setPipeline(this.pipeline);
passEncoder.setBindGroup(0, this.bindGroup);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.setVertexBuffer(0, this.vertexBuffer);
passEncoder.draw(4, 1);
passEncoder.end();
@ -104,35 +125,37 @@ export class CopyPipeline {
public destroy() {
this.vertexBuffer.destroy();
this.uniforms.destroy();
}
private ensureBindGroupExists(trailMapIn: GPUTextureView) {
if (this.previousTrailMapIn !== trailMapIn) {
this.bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: this.device.createSampler({
magFilter: 'linear',
minFilter: 'linear',
}),
},
{
binding: 2,
resource: trailMapIn,
},
],
});
this.previousTrailMapIn = trailMapIn;
private getBindGroup(trailMapIn: GPUTextureView): GPUBindGroup {
const cached = this.bindGroupsByInput.get(trailMapIn);
if (cached) {
return cached;
}
const bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: this.sampler,
},
{
binding: 2,
resource: trailMapIn,
},
],
});
this.bindGroupsByInput.set(trailMapIn, bindGroup);
return bindGroup;
}
private static get bindGroupLayout(): GPUBindGroupLayoutDescriptor {

View file

@ -0,0 +1,29 @@
import { describe, expect, it } from 'vitest';
import {
getSafeInverseDiffusionRate,
setDiffusionUniformValues,
} from './diffusion-pipeline';
describe('diffusion pipeline parameters', () => {
it('keeps zero diffusion rates finite before writing shader uniforms', () => {
const uniformValues = new Float32Array(4);
setDiffusionUniformValues(uniformValues, {
decayRateBrush: 900,
decayRateTrails: 970,
diffusionRateBrush: 0,
diffusionRateTrails: 0,
});
expect(Number.isFinite(uniformValues[0])).toBe(true);
expect(Number.isFinite(uniformValues[2])).toBe(true);
expect(uniformValues[0]).toBeGreaterThan(0);
expect(uniformValues[2]).toBeGreaterThan(0);
});
it('passes valid diffusion rates through as inverse values', () => {
expect(getSafeInverseDiffusionRate(2)).toBe(0.5);
expect(getSafeInverseDiffusionRate(0.25)).toBe(4);
});
});

View file

@ -1,3 +1,4 @@
import { appConfig } from '../../config';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
@ -8,8 +9,36 @@ import { CommonState } from '../common-state/common-state';
import shader from './diffuse.wgsl?raw';
import { DiffusionSettings } from './diffusion-settings';
const MIN_DIFFUSION_RATE = appConfig.pipelines.diffusion.minDiffusionRate;
type DiffusionUniformSettings = Pick<
DiffusionSettings,
'diffusionRateTrails' | 'decayRateTrails' | 'diffusionRateBrush' | 'decayRateBrush'
>;
export const getSafeInverseDiffusionRate = (diffusionRate: number): number =>
1 /
(Number.isFinite(diffusionRate) && diffusionRate > MIN_DIFFUSION_RATE
? diffusionRate
: MIN_DIFFUSION_RATE);
export const setDiffusionUniformValues = (
target: Float32Array,
{
diffusionRateTrails,
decayRateTrails,
diffusionRateBrush,
decayRateBrush,
}: DiffusionUniformSettings
): void => {
target[0] = getSafeInverseDiffusionRate(diffusionRateTrails);
target[1] = decayRateTrails / 1000;
target[2] = getSafeInverseDiffusionRate(diffusionRateBrush);
target[3] = decayRateBrush / 1000;
};
export class DiffusionPipeline {
private static readonly UNIFORM_COUNT = 5;
private static readonly UNIFORM_COUNT = 4;
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly pipeline: GPURenderPipeline;
@ -18,10 +47,10 @@ export class DiffusionPipeline {
private readonly uniformCache = createCachedFloat32BufferWrite(
DiffusionPipeline.UNIFORM_COUNT
);
private readonly sampler: GPUSampler;
private readonly vertexBuffer: GPUBuffer;
private bindGroup?: GPUBindGroup;
private previousTrailMapIn?: GPUTextureView;
private readonly bindGroupsByInput = new WeakMap<GPUTextureView, GPUBindGroup>();
public constructor(
private readonly device: GPUDevice,
@ -57,6 +86,11 @@ export class DiffusionPipeline {
size: DiffusionPipeline.UNIFORM_COUNT * Float32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
this.sampler = this.device.createSampler({
magFilter: 'linear',
minFilter: 'linear',
});
}
public setParameters({
@ -64,13 +98,13 @@ export class DiffusionPipeline {
decayRateTrails,
diffusionRateBrush,
decayRateBrush,
anisotropy,
}: DiffusionSettings) {
this.uniformValues[0] = 1 / diffusionRateTrails;
this.uniformValues[1] = decayRateTrails / 1000;
this.uniformValues[2] = 1 / diffusionRateBrush;
this.uniformValues[3] = decayRateBrush / 1000;
this.uniformValues[4] = anisotropy;
setDiffusionUniformValues(this.uniformValues, {
diffusionRateTrails,
decayRateTrails,
diffusionRateBrush,
decayRateBrush,
});
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
@ -84,7 +118,7 @@ export class DiffusionPipeline {
trailMapIn: GPUTextureView,
trailMapOut: GPUTextureView
) {
this.ensureBindGroupExists(trailMapIn);
const bindGroup = this.getBindGroup(trailMapIn);
const renderPassDescriptor: GPURenderPassDescriptor = {
colorAttachments: [
@ -101,38 +135,39 @@ export class DiffusionPipeline {
passEncoder.setPipeline(this.pipeline);
passEncoder.setVertexBuffer(0, this.vertexBuffer);
this.commonState.execute(passEncoder);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.setBindGroup(1, bindGroup);
passEncoder.draw(4, 1);
passEncoder.end();
}
private ensureBindGroupExists(trailMapIn: GPUTextureView) {
if (this.previousTrailMapIn !== trailMapIn) {
this.bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: this.device.createSampler({
magFilter: 'linear',
minFilter: 'linear',
}),
},
{
binding: 2,
resource: trailMapIn,
},
],
});
this.previousTrailMapIn = trailMapIn;
private getBindGroup(trailMapIn: GPUTextureView): GPUBindGroup {
const cached = this.bindGroupsByInput.get(trailMapIn);
if (cached) {
return cached;
}
const bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: this.sampler,
},
{
binding: 2,
resource: trailMapIn,
},
],
});
this.bindGroupsByInput.set(trailMapIn, bindGroup);
return bindGroup;
}
public destroy() {

View file

@ -3,4 +3,5 @@ export interface DiffusionSettings {
decayRateTrails: number;
diffusionRateBrush: number;
decayRateBrush: number;
brushEffectDuration: number;
}

View file

@ -0,0 +1,244 @@
import { vec2 } from 'gl-matrix';
import { appConfig } from '../../config';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
} from '../../utils/graphics/cached-buffer-write';
import { getWorkgroupCounts } from '../../utils/graphics/get-workgroup-counts';
import { smartCompile } from '../../utils/graphics/smart-compile';
import agentSchema from '../agents/agent-generation/agent-schema.wgsl?raw';
import { CommonState } from '../common-state/common-state';
import shader from './eraser-agent.wgsl?raw';
interface LineSegment {
from: vec2;
to: vec2;
}
const shaderWithConfig = shader.replace(
'const MAX_SEGMENT_COUNT = 384u;',
`const MAX_SEGMENT_COUNT = ${Math.round(appConfig.pipelines.eraser.maxSegmentCount)}u;`
);
export class EraserAgentPipeline {
private static readonly WORKGROUP_SIZE = appConfig.pipelines.eraser.workgroupSize;
private static readonly UNIFORM_COUNT = 4;
private static readonly MAX_SEGMENT_COUNT = appConfig.pipelines.eraser.maxSegmentCount;
private static readonly SEGMENT_FLOAT_COUNT =
appConfig.pipelines.eraser.segmentFloatCount;
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly bindGroup: GPUBindGroup;
private readonly pipeline: GPUComputePipeline;
private readonly uniforms: GPUBuffer;
private readonly uniformValues = new Float32Array(EraserAgentPipeline.UNIFORM_COUNT);
private readonly uniformCache = createCachedFloat32BufferWrite(
EraserAgentPipeline.UNIFORM_COUNT
);
private readonly segmentsBuffer: GPUBuffer;
private readonly segmentUploadData = new Float32Array(
EraserAgentPipeline.MAX_SEGMENT_COUNT * EraserAgentPipeline.SEGMENT_FLOAT_COUNT
);
private linePoints: Array<vec2> = [];
private lineSegments: Array<LineSegment> = [];
private actualSegments: Array<LineSegment> = [];
private segmentCount = 0;
private agentCount = 0;
public constructor(
private readonly device: GPUDevice,
private readonly commonState: CommonState,
private readonly agentsBuffer: GPUBuffer
) {
this.bindGroupLayout = device.createBindGroupLayout({
entries: [
{
binding: 0,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'uniform',
},
},
{
binding: 1,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'storage',
},
},
{
binding: 2,
visibility: GPUShaderStage.COMPUTE,
buffer: {
type: 'read-only-storage',
},
},
],
});
this.uniforms = this.device.createBuffer({
size: EraserAgentPipeline.UNIFORM_COUNT * Float32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
this.segmentsBuffer = this.device.createBuffer({
size:
EraserAgentPipeline.MAX_SEGMENT_COUNT *
EraserAgentPipeline.SEGMENT_FLOAT_COUNT *
Float32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
this.bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: {
buffer: this.agentsBuffer,
},
},
{
binding: 2,
resource: {
buffer: this.segmentsBuffer,
},
},
],
});
this.pipeline = device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [commonState.bindGroupLayout, this.bindGroupLayout],
}),
compute: {
module: smartCompile(
device,
CommonState.shaderCode,
agentSchema,
shaderWithConfig
),
entryPoint: 'main',
},
});
}
public addSwipe(position: vec2): void {
const previousPosition = this.linePoints[this.linePoints.length - 1] ?? position;
this.addSwipeSegment(previousPosition, position);
this.linePoints.push(vec2.clone(position));
}
public addSwipeSegment(from: vec2, to: vec2): void {
this.lineSegments.push({
from: vec2.clone(from),
to: vec2.clone(to),
});
}
public clearSwipes(): void {
this.linePoints.length = 0;
this.lineSegments.length = 0;
this.actualSegments.length = 0;
this.segmentCount = 0;
}
public setParameters({
agentCount,
eraserSize,
}: {
agentCount: number;
eraserSize: number;
}): void {
this.agentCount = agentCount;
this.actualSegments = this.lineSegments.slice();
this.lineSegments.length = 0;
if (this.actualSegments.length > EraserAgentPipeline.MAX_SEGMENT_COUNT) {
this.actualSegments = EraserAgentPipeline.subsampleSegments(this.actualSegments);
}
this.segmentCount = Math.max(0, this.actualSegments.length);
const eraserRadius = eraserSize / 2;
this.uniformValues[0] = eraserRadius;
this.uniformValues[1] = this.segmentCount;
this.uniformValues[2] = agentCount;
this.uniformValues[3] = eraserRadius * eraserRadius;
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
this.uniformValues,
this.uniformCache
);
if (this.segmentCount === 0) {
return;
}
for (let i = 0; i < this.segmentCount; i++) {
const { from, to } = this.actualSegments[i];
const offset = i * EraserAgentPipeline.SEGMENT_FLOAT_COUNT;
this.segmentUploadData[offset] = from[0];
this.segmentUploadData[offset + 1] = from[1];
this.segmentUploadData[offset + 2] = to[0];
this.segmentUploadData[offset + 3] = to[1];
}
this.device.queue.writeBuffer(
this.segmentsBuffer,
0,
this.segmentUploadData,
0,
this.segmentCount * EraserAgentPipeline.SEGMENT_FLOAT_COUNT
);
}
public execute(commandEncoder: GPUCommandEncoder): void {
if (this.segmentCount === 0 || this.agentCount === 0) {
return;
}
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(this.pipeline);
this.commonState.execute(passEncoder);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.dispatchWorkgroups(
...getWorkgroupCounts(
this.device,
this.agentCount,
EraserAgentPipeline.WORKGROUP_SIZE
)
);
passEncoder.end();
}
public destroy(): void {
this.uniforms.destroy();
this.segmentsBuffer.destroy();
}
private static subsampleSegments(segments: Array<LineSegment>): Array<LineSegment> {
if (segments.length <= EraserAgentPipeline.MAX_SEGMENT_COUNT) {
return segments;
}
const result: Array<LineSegment> = [];
for (let i = 0; i < EraserAgentPipeline.MAX_SEGMENT_COUNT; i++) {
const index = Math.round(
(i * (segments.length - 1)) / (EraserAgentPipeline.MAX_SEGMENT_COUNT - 1)
);
result.push(segments[index]);
}
return result;
}
}

View file

@ -0,0 +1,63 @@
struct Settings {
eraserRadius: f32,
segmentCount: f32,
agentCount: f32,
eraserRadiusSquared: f32,
};
const MAX_SEGMENT_COUNT = 384u;
@group(1) @binding(0) var<uniform> settings: Settings;
@group(1) @binding(2) var<storage, read> segments: array<vec4<f32>>;
@compute @workgroup_size(64)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) workgroup_count: vec3<u32>
) {
let id = get_id(global_id, workgroup_count);
if id >= u32(settings.agentCount) {
return;
}
var agent = agents[id];
if agent.colorIndex < 0.0 {
return;
}
for (var i = 0u; i < MAX_SEGMENT_COUNT; i++) {
if i >= u32(settings.segmentCount) {
break;
}
let segment = segments[i];
let distanceSquared = distanceSquaredFromLine(
agent.position,
segment.xy,
segment.zw
);
if distanceSquared <= settings.eraserRadiusSquared {
agent.position = vec2<f32>(-1.0, -1.0);
agent.targetPosition = vec2<f32>(-1.0, -1.0);
agent.colorIndex = -1.0;
agents[id] = agent;
return;
}
}
}
fn distanceSquaredFromLine(position: vec2<f32>, start: vec2<f32>, end: vec2<f32>) -> f32 {
let pa = position - start;
let direction = end - start;
let denominator = dot(direction, direction);
if denominator <= 0.0001 {
return dot(pa, pa);
}
let q = clamp(dot(pa, direction) / denominator, 0.0, 1.0);
let nearestOffset = pa - direction * q;
return dot(nearestOffset, nearestOffset);
}

View file

@ -0,0 +1,333 @@
import { vec2 } from 'gl-matrix';
import { appConfig } from '../../config';
import { clamp } from '../../utils/clamp';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
} from '../../utils/graphics/cached-buffer-write';
import { smartCompile } from '../../utils/graphics/smart-compile';
import { CommonState } from '../common-state/common-state';
import shader from './eraser-texture.wgsl?raw';
interface LineSegment {
from: vec2;
to: vec2;
}
export class EraserTexturePipeline {
private static readonly UNIFORM_COUNT = 4;
private static readonly MAX_LINE_COUNT = appConfig.pipelines.eraser.maxTextureLineCount;
private static readonly VERTICES_PER_LINE_SEGMENT = 6;
private static readonly ATTRIBUTES_PER_LINE_SEGMENT = 6;
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly bindGroup: GPUBindGroup;
private readonly pipeline: GPURenderPipeline;
private readonly uniforms: GPUBuffer;
private readonly uniformValues = new Float32Array(EraserTexturePipeline.UNIFORM_COUNT);
private readonly uniformCache = createCachedFloat32BufferWrite(
EraserTexturePipeline.UNIFORM_COUNT
);
private readonly vertexBuffer: GPUBuffer;
private readonly vertexUploadData = new Float32Array(
EraserTexturePipeline.MAX_LINE_COUNT *
EraserTexturePipeline.VERTICES_PER_LINE_SEGMENT *
EraserTexturePipeline.ATTRIBUTES_PER_LINE_SEGMENT
);
private linePoints: Array<vec2> = [];
private lineSegments: Array<LineSegment> = [];
private actualSegments: Array<LineSegment> = [];
public constructor(
private readonly device: GPUDevice,
private readonly commonState: CommonState
) {
this.bindGroupLayout = device.createBindGroupLayout({
entries: [
{
binding: 0,
visibility: GPUShaderStage.FRAGMENT,
buffer: {
type: 'uniform',
},
},
],
});
this.vertexBuffer = device.createBuffer({
size:
EraserTexturePipeline.MAX_LINE_COUNT *
EraserTexturePipeline.VERTICES_PER_LINE_SEGMENT *
EraserTexturePipeline.ATTRIBUTES_PER_LINE_SEGMENT *
Float32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
});
this.pipeline = device.createRenderPipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [commonState.bindGroupLayout, this.bindGroupLayout],
}),
vertex: {
module: smartCompile(device, CommonState.shaderCode, shader),
entryPoint: 'vertex',
buffers: [
{
arrayStride: Float32Array.BYTES_PER_ELEMENT * 6,
attributes: [
{
shaderLocation: 0,
format: 'float32x2',
offset: 0,
},
{
shaderLocation: 1,
format: 'float32x2',
offset: Float32Array.BYTES_PER_ELEMENT * 2,
},
{
shaderLocation: 2,
format: 'float32x2',
offset: Float32Array.BYTES_PER_ELEMENT * 4,
},
],
},
],
},
fragment: {
module: smartCompile(device, CommonState.shaderCode, shader),
entryPoint: 'fragment',
targets: [
{
format: 'rgba16float',
},
],
},
primitive: {
topology: 'triangle-list',
},
});
this.uniforms = this.device.createBuffer({
size: EraserTexturePipeline.UNIFORM_COUNT * Float32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
this.bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
],
});
}
public addSwipe(position: vec2): void {
const previousPosition = this.linePoints[this.linePoints.length - 1] ?? position;
this.addSwipeSegment(previousPosition, position);
this.linePoints.push(vec2.clone(position));
}
public addSwipeSegment(from: vec2, to: vec2): void {
this.lineSegments.push({
from: vec2.clone(from),
to: vec2.clone(to),
});
}
public clearSwipes(): void {
this.linePoints.length = 0;
this.lineSegments.length = 0;
this.actualSegments.length = 0;
}
public setParameters({ eraserSize }: { eraserSize: number }): void {
const eraserRadius = eraserSize / 2;
this.uniformValues[0] = eraserRadius;
this.uniformValues[1] = eraserRadius * eraserRadius;
this.uniformValues[2] = 0;
this.uniformValues[3] = 0;
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
this.uniformValues,
this.uniformCache
);
this.actualSegments = this.lineSegments.slice();
this.lineSegments.length = 0;
if (this.actualSegments.length === 0) {
return;
}
if (this.actualSegments.length > EraserTexturePipeline.MAX_LINE_COUNT) {
this.actualSegments = EraserTexturePipeline.subsampleSegments(this.actualSegments);
}
const lineCount = this.lineCount;
let floatOffset = 0;
for (let i = 0; i < lineCount; i++) {
const segment = this.actualSegments[i];
floatOffset = this.writeSegmentVertices(
this.vertexUploadData,
floatOffset,
segment.from,
segment.to,
eraserRadius
);
}
this.device.queue.writeBuffer(
this.vertexBuffer,
0,
this.vertexUploadData,
0,
floatOffset
);
}
public execute(commandEncoder: GPUCommandEncoder, textureOut: GPUTextureView): void {
if (this.lineCount === 0) {
return;
}
const renderPassDescriptor: GPURenderPassDescriptor = {
colorAttachments: [
{
view: textureOut,
loadOp: 'load',
storeOp: 'store',
},
],
};
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
passEncoder.setPipeline(this.pipeline);
this.commonState.execute(passEncoder);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.setVertexBuffer(0, this.vertexBuffer);
passEncoder.draw(EraserTexturePipeline.VERTICES_PER_LINE_SEGMENT * this.lineCount, 1);
passEncoder.end();
}
public destroy(): void {
this.vertexBuffer.destroy();
this.uniforms.destroy();
}
private static subsampleSegments(segments: Array<LineSegment>): Array<LineSegment> {
if (segments.length <= EraserTexturePipeline.MAX_LINE_COUNT) {
return segments;
}
const result: Array<LineSegment> = [];
for (let i = 0; i < EraserTexturePipeline.MAX_LINE_COUNT; i++) {
const index = Math.round(
(i * (segments.length - 1)) / (EraserTexturePipeline.MAX_LINE_COUNT - 1)
);
result.push(segments[index]);
}
return result;
}
private writeSegmentVertices(
target: Float32Array,
offset: number,
from: vec2,
to: vec2,
width: number
): number {
const dx = to[0] - from[0];
const dy = to[1] - from[1];
const length = Math.hypot(dx, dy);
const directionX = length > 0 ? dx / length : 1;
const directionY = length > 0 ? dy / length : 0;
const scaledDirectionX = directionX * width;
const scaledDirectionY = directionY * width;
const perpendicularX = directionY * width;
const perpendicularY = -directionX * width;
const startX = from[0] - scaledDirectionX;
const startY = from[1] - scaledDirectionY;
const endX = to[0] + scaledDirectionX;
const endY = to[1] + scaledDirectionY;
offset = this.writeVertex(
target,
offset,
startX + perpendicularX,
startY + perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
startX - perpendicularX,
startY - perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
endX + perpendicularX,
endY + perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
startX - perpendicularX,
startY - perpendicularY,
from,
to
);
offset = this.writeVertex(
target,
offset,
endX + perpendicularX,
endY + perpendicularY,
from,
to
);
return this.writeVertex(
target,
offset,
endX - perpendicularX,
endY - perpendicularY,
from,
to
);
}
private writeVertex(
target: Float32Array,
offset: number,
screenX: number,
screenY: number,
from: vec2,
to: vec2
): number {
target[offset++] = screenX;
target[offset++] = screenY;
target[offset++] = from[0];
target[offset++] = from[1];
target[offset++] = to[0];
target[offset++] = to[1];
return offset;
}
private get lineCount(): number {
return clamp(this.actualSegments.length, 0, EraserTexturePipeline.MAX_LINE_COUNT);
}
}

View file

@ -0,0 +1,53 @@
struct Settings {
eraserRadius: f32,
eraserRadiusSquared: f32,
padding1: f32,
padding2: f32,
};
@group(1) @binding(0) var<uniform> settings: Settings;
struct VertexOutput {
@builtin(position) position: vec4<f32>,
@location(0) screenPosition: vec2<f32>,
@location(1) start: vec2<f32>,
@location(2) end: vec2<f32>
}
@vertex
fn vertex(
@location(0) screenPosition: vec2<f32>,
@location(1) @interpolate(flat) start: vec2<f32>,
@location(2) @interpolate(flat) end: vec2<f32>
) -> VertexOutput {
let uv = screenPosition / state.size;
let position = vec2(uv.x * 2.0 - 1.0, 1.0 - uv.y * 2.0);
return VertexOutput(vec4(position, 0.0, 1.0), screenPosition, start, end);
}
@fragment
fn fragment(
@location(0) screenPosition: vec2<f32>,
@location(1) start: vec2<f32>,
@location(2) end: vec2<f32>
) -> @location(0) vec4<f32> {
if distanceSquaredFromLine(screenPosition, start, end) > settings.eraserRadiusSquared {
discard;
}
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
}
fn distanceSquaredFromLine(position: vec2<f32>, start: vec2<f32>, end: vec2<f32>) -> f32 {
let pa = position - start;
let direction = end - start;
let denominator = dot(direction, direction);
if denominator <= 0.0001 {
return dot(pa, pa);
}
let q = clamp(dot(pa, direction) / denominator, 0.0, 1.0);
let nearestOffset = pa - direction * q;
return dot(nearestOffset, nearestOffset);
}

View file

@ -1,5 +1,7 @@
import { vec3 } from 'gl-matrix';
import {
createCachedFloat32BufferWrite,
writeFloat32BufferIfChanged,
} from '../../utils/graphics/cached-buffer-write';
import { setUpFullScreenQuad } from '../../utils/graphics/full-screen-quad';
import { smartCompile } from '../../utils/graphics/smart-compile';
import { CommonState } from '../common-state/common-state';
@ -7,15 +9,23 @@ import { RenderSettings } from './render-settings';
import shader from './render.wgsl?raw';
export class RenderPipeline {
private static readonly UNIFORM_COUNT = 13;
private static readonly UNIFORM_COUNT = 20;
private readonly bindGroupLayout: GPUBindGroupLayout;
private readonly pipeline: GPURenderPipeline;
private readonly canvasPipeline: GPURenderPipeline;
private readonly exportPipeline: GPURenderPipeline;
private readonly sampler: GPUSampler;
private readonly uniforms: GPUBuffer;
private readonly uniformValues = new Float32Array(RenderPipeline.UNIFORM_COUNT);
private readonly uniformCache = createCachedFloat32BufferWrite(
RenderPipeline.UNIFORM_COUNT
);
private readonly vertexBuffer: GPUBuffer;
private bindGroup?: GPUBindGroup;
private previousColorTexture?: GPUTextureView;
private readonly bindGroupsByTexture = new WeakMap<
GPUTextureView,
WeakMap<GPUTextureView, GPUBindGroup>
>();
public constructor(
private readonly context: GPUCanvasContext,
@ -27,104 +37,179 @@ export class RenderPipeline {
const { buffer, vertex } = setUpFullScreenQuad(device);
this.vertexBuffer = buffer;
this.pipeline = device.createRenderPipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [commonState.bindGroupLayout, this.bindGroupLayout],
}),
vertex,
fragment: {
module: smartCompile(device, CommonState.shaderCode, shader),
entryPoint: 'fragment',
targets: [
{
format: navigator.gpu.getPreferredCanvasFormat(),
},
],
},
primitive: {
topology: 'triangle-strip',
},
this.sampler = device.createSampler({
magFilter: 'linear',
minFilter: 'linear',
});
const format = navigator.gpu.getPreferredCanvasFormat();
this.canvasPipeline = this.createPipeline(format, vertex);
this.exportPipeline = this.createPipeline(format, vertex);
this.uniforms = this.device.createBuffer({
size: RenderPipeline.UNIFORM_COUNT * Float32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});
}
private createPipeline(
format: GPUTextureFormat,
vertex: GPUVertexState
): GPURenderPipeline {
return this.device.createRenderPipeline({
layout: this.device.createPipelineLayout({
bindGroupLayouts: [this.commonState.bindGroupLayout, this.bindGroupLayout],
}),
vertex,
fragment: {
module: smartCompile(this.device, CommonState.shaderCode, shader),
entryPoint: 'fragment',
targets: [
{
format,
},
],
},
primitive: {
topology: 'triangle-strip',
},
});
}
public setParameters({
brushColor,
evenGenerationColor,
oddGenerationColor,
channelColors,
backgroundColor,
cameraCenter,
cameraZoom,
clarity,
}: RenderSettings & {
brushColor: vec3;
evenGenerationColor: vec3;
oddGenerationColor: vec3;
channelColors: Array<[number, number, number]>;
backgroundColor: [number, number, number];
cameraCenter: [number, number];
cameraZoom: number;
}) {
this.device.queue.writeBuffer(
const [a, b, c] = channelColors;
this.uniformValues[0] = a[0];
this.uniformValues[1] = a[1];
this.uniformValues[2] = a[2];
this.uniformValues[3] = 0;
this.uniformValues[4] = b[0];
this.uniformValues[5] = b[1];
this.uniformValues[6] = b[2];
this.uniformValues[7] = 0;
this.uniformValues[8] = c[0];
this.uniformValues[9] = c[1];
this.uniformValues[10] = c[2];
this.uniformValues[11] = 0;
this.uniformValues[12] = backgroundColor[0];
this.uniformValues[13] = backgroundColor[1];
this.uniformValues[14] = backgroundColor[2];
this.uniformValues[15] = clarity;
this.uniformValues[16] = cameraCenter[0];
this.uniformValues[17] = cameraCenter[1];
this.uniformValues[18] = cameraZoom;
this.uniformValues[19] = 0;
writeFloat32BufferIfChanged(
this.device,
this.uniforms,
0,
new Float32Array([
...brushColor,
0, //padding
...evenGenerationColor,
0, //padding
...oddGenerationColor,
clarity,
])
this.uniformValues,
this.uniformCache
);
}
public execute(commandEncoder: GPUCommandEncoder, colorTexture: GPUTextureView) {
this.ensureBindGroupExists(colorTexture);
public execute(
commandEncoder: GPUCommandEncoder,
colorTexture: GPUTextureView,
sourceTexture: GPUTextureView
) {
const bindGroup = this.getBindGroup(colorTexture, sourceTexture);
const renderPassDescriptor: GPURenderPassDescriptor = {
colorAttachments: [
{
view: this.context.getCurrentTexture().createView(),
clearValue: { r: 0, g: 1, b: 1, a: 1 },
clearValue: { r: 0, g: 0, b: 0, a: 1 },
loadOp: 'clear',
storeOp: 'store',
},
],
};
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
passEncoder.setPipeline(this.pipeline);
passEncoder.setPipeline(this.canvasPipeline);
this.commonState.execute(passEncoder);
passEncoder.setVertexBuffer(0, this.vertexBuffer);
passEncoder.setBindGroup(1, this.bindGroup);
passEncoder.setBindGroup(1, bindGroup);
passEncoder.draw(4, 1);
passEncoder.end();
}
private ensureBindGroupExists(colorTexture: GPUTextureView) {
if (this.previousColorTexture !== colorTexture) {
this.bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: this.device.createSampler({
magFilter: 'linear',
minFilter: 'linear',
}),
},
{
binding: 2,
resource: colorTexture,
},
],
});
public executeToView(
commandEncoder: GPUCommandEncoder,
colorTexture: GPUTextureView,
sourceTexture: GPUTextureView,
outputTexture: GPUTextureView
) {
const bindGroup = this.getBindGroup(colorTexture, sourceTexture);
this.previousColorTexture = colorTexture;
const passEncoder = commandEncoder.beginRenderPass({
colorAttachments: [
{
view: outputTexture,
clearValue: { r: 0, g: 0, b: 0, a: 1 },
loadOp: 'clear',
storeOp: 'store',
},
],
});
passEncoder.setPipeline(this.exportPipeline);
this.commonState.execute(passEncoder);
passEncoder.setVertexBuffer(0, this.vertexBuffer);
passEncoder.setBindGroup(1, bindGroup);
passEncoder.draw(4, 1);
passEncoder.end();
}
private getBindGroup(
colorTexture: GPUTextureView,
sourceTexture: GPUTextureView
): GPUBindGroup {
let sourceTextureCache = this.bindGroupsByTexture.get(colorTexture);
if (!sourceTextureCache) {
sourceTextureCache = new WeakMap<GPUTextureView, GPUBindGroup>();
this.bindGroupsByTexture.set(colorTexture, sourceTextureCache);
}
const cached = sourceTextureCache.get(sourceTexture);
if (cached) {
return cached;
}
const bindGroup = this.device.createBindGroup({
layout: this.bindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: this.uniforms,
},
},
{
binding: 1,
resource: this.sampler,
},
{
binding: 2,
resource: colorTexture,
},
{
binding: 3,
resource: sourceTexture,
},
],
});
sourceTextureCache.set(sourceTexture, bindGroup);
return bindGroup;
}
public destroy() {
@ -156,6 +241,13 @@ export class RenderPipeline {
sampleType: 'float',
},
},
{
binding: 3,
visibility: GPUShaderStage.FRAGMENT,
texture: {
sampleType: 'float',
},
},
],
};
}

View file

@ -0,0 +1,193 @@
import { describe, expect, it } from 'vitest';
import compactionShader from './agents/agent-generation/agent-compaction.wgsl?raw';
import countingShader from './agents/agent-generation/agent-counting.wgsl?raw';
import { AgentGenerationPipeline } from './agents/agent-generation/agent-generation-pipeline';
import resizeShader from './agents/agent-generation/agent-resize.wgsl?raw';
import { AgentPipeline } from './agents/agent-pipeline';
import agentShader from './agents/agent.wgsl?raw';
import { BrushPipeline } from './brush/brush-pipeline';
import brushShader from './brush/brush.wgsl?raw';
import { CommonState } from './common-state/common-state';
import { CopyPipeline } from './copy/copy-pipeline';
import copyShader from './copy/copy.wgsl?raw';
import diffusionShader from './diffusion/diffuse.wgsl?raw';
import { DiffusionPipeline } from './diffusion/diffusion-pipeline';
import { EraserAgentPipeline } from './eraser/eraser-agent-pipeline';
import eraserAgentShader from './eraser/eraser-agent.wgsl?raw';
import { EraserTexturePipeline } from './eraser/eraser-texture-pipeline';
import eraserTextureShader from './eraser/eraser-texture.wgsl?raw';
import { RenderPipeline } from './render/render-pipeline';
import renderShader from './render/render.wgsl?raw';
const wgslFloatCountsByType: Record<string, number> = {
f32: 1,
u32: 1,
'vec2<f32>': 2,
'vec3<f32>': 3,
'vec4<f32>': 4,
};
const stripComments = (source: string): string =>
source.replace(/\/\/.*$/gm, '').replace(/\/\*[\s\S]*?\*\//g, '');
const getStructFields = (source: string, structName: string) => {
const match = new RegExp(
`struct ${structName}\\s*\\{(?<body>[\\s\\S]*?)\\n\\s*\\}`
).exec(stripComments(source));
if (!match?.groups?.body) {
throw new Error(`${structName} struct was not found`);
}
return match.groups.body
.split('\n')
.map((line) => line.trim().replace(/,$/, ''))
.filter(Boolean)
.map((line) => {
const fieldMatch = /^(?<name>\w+):\s*(?<type>[^,]+)$/.exec(line);
if (!fieldMatch?.groups) {
throw new Error(`Unsupported WGSL struct field syntax: ${line}`);
}
return {
name: fieldMatch.groups.name,
type: fieldMatch.groups.type,
};
});
};
const countUniformScalars = (source: string, structName: string): number =>
getStructFields(source, structName).reduce((sum, field) => {
const count = wgslFloatCountsByType[field.type];
if (!count) {
throw new Error(`Unsupported WGSL uniform field type: ${field.type}`);
}
return sum + count;
}, 0);
const getUniformCount = (pipeline: unknown): number =>
(pipeline as { UNIFORM_COUNT: number }).UNIFORM_COUNT;
const expectStructUniformLayout = ({
pipeline,
source,
structName,
fieldNames,
}: {
pipeline: unknown;
source: string;
structName: string;
fieldNames: Array<string>;
}) => {
const fields = getStructFields(source, structName);
expect(fields.map((field) => field.name)).toEqual(fieldNames);
expect(countUniformScalars(source, structName)).toBe(getUniformCount(pipeline));
};
describe('WGSL uniform layout contracts', () => {
it('keeps shared common-state uniforms aligned with WGSL', () => {
expectStructUniformLayout({
pipeline: CommonState,
source: CommonState.shaderCode,
structName: 'State',
fieldNames: ['size', 'deltaTime', 'time'],
});
});
it('keeps render and simulation uniforms aligned with WGSL', () => {
expectStructUniformLayout({
pipeline: AgentPipeline,
source: agentShader,
structName: 'Settings',
fieldNames: [
'moveRate',
'turnRate',
'sensorAngle',
'sensorOffset',
'turnWhenLost',
'individualTrailWeight',
'agentCount',
'introProgress',
],
});
expectStructUniformLayout({
pipeline: BrushPipeline,
source: brushShader,
structName: 'Settings',
fieldNames: [
'brushSize',
'brushSizeVariation',
'padding0',
'padding1',
'brushValue',
],
});
expectStructUniformLayout({
pipeline: DiffusionPipeline,
source: diffusionShader,
structName: 'Settings',
fieldNames: [
'inverseDiffusionRateTrails',
'decayRateTrails',
'inverseDiffusionRateBrush',
'decayRateBrush',
],
});
expectStructUniformLayout({
pipeline: RenderPipeline,
source: renderShader,
structName: 'Settings',
fieldNames: [
'colorA',
'backgroundColorPadding0',
'colorB',
'backgroundColorPadding1',
'colorC',
'backgroundColorPadding2',
'backgroundColor',
'clarity',
'cameraCenter',
'cameraZoom',
'padding0',
],
});
});
it('keeps eraser uniforms aligned with WGSL', () => {
expectStructUniformLayout({
pipeline: EraserAgentPipeline,
source: eraserAgentShader,
structName: 'Settings',
fieldNames: ['eraserRadius', 'segmentCount', 'agentCount', 'eraserRadiusSquared'],
});
expectStructUniformLayout({
pipeline: EraserTexturePipeline,
source: eraserTextureShader,
structName: 'Settings',
fieldNames: ['eraserRadius', 'eraserRadiusSquared', 'padding1', 'padding2'],
});
});
it('keeps copy uniforms aligned with WGSL', () => {
const match = /var<uniform>\s+sourceScaler:\s*(?<type>[^;]+);/.exec(copyShader);
expect(match?.groups?.type).toBe('vec2<f32>');
expect(wgslFloatCountsByType[match?.groups?.type ?? '']).toBe(
getUniformCount(CopyPipeline)
);
});
it('keeps agent-generation uniforms large enough for every generation shader', () => {
const generationUniformCounts = [
countUniformScalars(countingShader, 'Settings'),
countUniformScalars(resizeShader, 'ResizeSettings'),
countUniformScalars(compactionShader, 'Settings'),
];
expect(Math.max(...generationUniformCounts)).toBe(
getUniformCount(AgentGenerationPipeline)
);
});
});