fleeting-garden/src/pipelines/diffusion/diffuse.wgsl
2026-05-21 07:43:10 +01:00

152 lines
5.3 KiB
WebGPU Shading Language

// Uniform parameters for the diffusion pass (bound at @group(0) @binding(0)).
// Eight f32 fields; the order and padding must match the host-side buffer.
struct Settings {
// Forwarded to diffusion_weight() as its `inverseRate`, shaping the random
// per-pixel weight given to each neighbour.
inverseDiffusionRateTrails: f32,
// Multiplier applied to the RGB channels after diffusion, before the final
// clamp to [0, 1].
decayRateTrails: f32,
// Scales the summed contribution of the 8 neighbours before it is added to
// the centre texel.
diffusionNeighborScale: f32,
// Alpha decays as max(0, a * brushDecayAlphaMultiplier - brushDecayAlphaSubtract),
// then is clamped to [0, 1].
brushDecayAlphaMultiplier: f32,
brushDecayAlphaSubtract: f32,
// Not read by this shader. NOTE(review): presumably keeps the host-side
// struct at 8 floats / 32 bytes — confirm against the CPU-side definition.
padding0: f32,
padding1: f32,
padding2: f32,
};
// Invocations per workgroup along each axis. Shared-memory tile sizes below
// are derived from these values.
const WORKGROUP_SIZE_X = 16u;
const WORKGROUP_SIZE_Y = 16u;
// One-pixel halo on each side so the 3x3 neighbourhood read in the main pass
// can be served from workgroup memory without bounds checks for interior tiles.
const TILE_SIZE_X = WORKGROUP_SIZE_X + 2u;
const TILE_SIZE_Y = WORKGROUP_SIZE_Y + 2u;
const TILE_TEXEL_COUNT = TILE_SIZE_X * TILE_SIZE_Y;
// 1.0 / 2^32: scales a 32-bit integer hash into the unit interval.
const HASH_TO_UNIT_FLOAT: f32 = 2.3283064365386963e-10;
@group(0) @binding(0) var<uniform> settings: Settings;
// Source trail texture, read with textureLoad.
@group(0) @binding(1) var trailMap: texture_2d<f32>;
// Destination trail texture, written with textureStore (rgba16float).
@group(0) @binding(2) var trailMapOut: texture_storage_2d<rgba16float, write>;
// Workgroup-shared copy of the tile (including halo) and the RGB length of
// each tile texel. Sized from TILE_TEXEL_COUNT so the arrays cannot drift out
// of sync with the workgroup/tile constants above.
var<workgroup> tile: array<vec4<f32>, TILE_TEXEL_COUNT>;
var<workgroup> tileTrailStrength: array<f32, TILE_TEXEL_COUNT>;
// One diffusion + decay step over the trail map.
//
// Each invocation owns one output texel. The workgroup first cooperatively
// copies a TILE_SIZE_X x TILE_SIZE_Y block of `trailMap` (its own texels plus
// a one-texel halo) into workgroup memory, together with the RGB length of
// each texel. Every in-bounds invocation then pulls colour from its 8
// neighbours, applies decay, and writes the result to `trailMapOut`.
//
// The workgroup size comes from the same constants used for tile indexing,
// so the two can never disagree.
@compute @workgroup_size(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let textureSize = vec2<i32>(textureDimensions(trailMap, 0));
    let textureSizeU32 = vec2<u32>(textureSize);
    let localLinearIndex = local_id.y * WORKGROUP_SIZE_X + local_id.x;
    let workgroupOrigin = workgroup_id.xy * vec2<u32>(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y);

    // When the whole tile, halo included, lies inside the texture, the loads
    // below need no clamping.
    let isInteriorTile =
        workgroupOrigin.x > 0u &&
        workgroupOrigin.y > 0u &&
        workgroupOrigin.x + WORKGROUP_SIZE_X < textureSizeU32.x &&
        workgroupOrigin.y + WORKGROUP_SIZE_Y < textureSizeU32.y;

    // The tile has more texels than the workgroup has invocations, so each
    // invocation loads every `tileLoadStride`-th tile texel.
    const tileLoadStride = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y;
    for (var tileIndex = localLinearIndex; tileIndex < TILE_TEXEL_COUNT; tileIndex += tileLoadStride) {
        let tilePosition = vec2<u32>(tileIndex % TILE_SIZE_X, tileIndex / TILE_SIZE_X);
        // Tile position (0, 0) is the halo texel up-left of the workgroup origin.
        let unclampedSourcePixel = vec2<i32>(workgroupOrigin + tilePosition) - vec2<i32>(1, 1);
        var sourcePixel = unclampedSourcePixel;
        if !isInteriorTile {
            // Edge tiles replicate the border texel for halo/out-of-range reads.
            sourcePixel = clamp(unclampedSourcePixel, vec2<i32>(0, 0), textureSize - vec2<i32>(1, 1));
        }
        let texel = textureLoad(trailMap, sourcePixel, 0);
        tile[tileIndex] = texel;
        tileTrailStrength[tileIndex] = length(texel.rgb);
    }

    // All tile writes must be visible before any invocation reads neighbours.
    // Out-of-range invocations only exit after this point so the barrier is
    // reached by the whole workgroup in uniform control flow.
    workgroupBarrier();

    let pixel = vec2<i32>(i32(global_id.x), i32(global_id.y));
    let inBounds = pixel.x < textureSize.x && pixel.y < textureSize.y;
    if !inBounds {
        return;
    }

    // Offset by the halo to find this invocation's own texel in the tile.
    let centerTilePosition = local_id.xy + vec2<u32>(1u, 1u);
    let centerTileIndex = centerTilePosition.y * TILE_SIZE_X + centerTilePosition.x;
    var current = tile[centerTileIndex];

    let random = random_from_pixel(pixel);
    let trailWeight = diffusion_weight(
        random,
        settings.inverseDiffusionRateTrails
    );

    // Every propagate() call sees the pre-diffusion centre colour: the sum is
    // fully evaluated before `current` is updated.
    current += (
        propagate(centerTileIndex, -1, -1, current, trailWeight)
        + propagate(centerTileIndex, -1, 1, current, trailWeight)
        + propagate(centerTileIndex, 1, -1, current, trailWeight)
        + propagate(centerTileIndex, 1, 1, current, trailWeight)
        + propagate(centerTileIndex, -1, 0, current, trailWeight)
        + propagate(centerTileIndex, 0, -1, current, trailWeight)
        + propagate(centerTileIndex, 1, 0, current, trailWeight)
        + propagate(centerTileIndex, 0, 1, current, trailWeight)
    ) * settings.diffusionNeighborScale;

    // RGB decays multiplicatively; alpha decays by an affine step floored at 0.
    let decayed = clamp(vec4(
        current.rgb * settings.decayRateTrails,
        max(0.0, current.a * settings.brushDecayAlphaMultiplier - settings.brushDecayAlphaSubtract)
    ), vec4(0.0), vec4(1.0));
    textureStore(trailMapOut, pixel, decayed);
}
// Contribution one neighbour makes to the centre texel.
//
// Only channels where the neighbour is brighter than `currentColor` contribute
// (the difference is clamped to [0, 1]). RGB is weighted by the neighbour's
// RGB length times `trailWeight`; alpha by its own alpha times `trailWeight`.
// `offsetX`/`offsetY` are in texels relative to `centerTileIndex` in `tile`.
fn propagate(
    centerTileIndex: u32,
    offsetX: i32,
    offsetY: i32,
    currentColor: vec4<f32>,
    trailWeight: f32
) -> vec4<f32> {
    // The tile is row-major, TILE_SIZE_X texels per row.
    let rowStride = i32(TILE_SIZE_X);
    let sampleIndex = u32(i32(centerTileIndex) + offsetX + offsetY * rowStride);

    let neighbourTexel = tile[sampleIndex];
    let gain = vec4<f32>(
        vec3<f32>(tileTrailStrength[sampleIndex] * trailWeight),
        neighbourTexel.a * trailWeight
    );
    let positiveDelta = clamp(neighbourTexel - currentColor, vec4<f32>(0.0), vec4<f32>(1.0));
    return gain * positiveDelta;
}
// Deterministic per-pixel pseudo-random value in [0, 1).
//
// Mixes the pixel coordinates into a 32-bit hash with multiply/xor-shift
// rounds (u32 arithmetic wraps), then maps it to the unit interval.
fn random_from_pixel(pixel: vec2<i32>) -> f32 {
    let p = vec2<u32>(pixel);
    var hash = p.x * 1664525u + p.y * 1013904223u + 374761393u;
    hash = (hash ^ (hash >> 16u)) * 2246822519u;
    hash = (hash ^ (hash >> 13u)) * 3266489917u;
    hash = hash ^ (hash >> 16u);
    // Converting the full 32-bit hash to f32 can round up to 2^32, which
    // would yield exactly 1.0. Keeping only the top 24 bits makes the value
    // exactly representable in f32, so the result is at most 1 - 2^-24.
    return f32(hash & 0xFFFFFF00u) * HASH_TO_UNIT_FLOAT;
}
// Cheap stand-in for pow(r, inverseRate) that avoids a real pow() per pixel.
//
// For inverseRate between the anchor exponents 1, 2, 4, 8 and 16, the result
// linearly blends the neighbouring anchors r, r^2, r^4, r^8 and r^16.
// Below 1 it blends from the rational curve r / (0.5 + 0.5 r) (reached at
// inverseRate <= 0.5) towards r. Above 16 it is r^16 scaled by 16 / inverseRate.
fn diffusion_weight(
    r: f32,
    inverseRate: f32
) -> f32 {
    // Sub-linear regime: blend factor saturates at inverseRate 0.5 and 1.0.
    if inverseRate < 1.0 {
        let blend = clamp((inverseRate - 0.5) * 2.0, 0.0, 1.0);
        let softened = r / max(0.5 + r * 0.5, 0.0001);
        return mix(softened, r, blend);
    }

    let pow2 = r * r;
    if inverseRate < 2.0 {
        return mix(r, pow2, inverseRate - 1.0);
    }

    let pow4 = pow2 * pow2;
    if inverseRate < 4.0 {
        // t = (inverseRate - 2) / (4 - 2)
        return mix(pow2, pow4, (inverseRate - 2.0) * 0.5);
    }

    let pow8 = pow4 * pow4;
    if inverseRate < 8.0 {
        // t = (inverseRate - 4) / (8 - 4)
        return mix(pow4, pow8, (inverseRate - 4.0) * 0.25);
    }

    let pow16 = pow8 * pow8;
    // t = (inverseRate - 8) / (16 - 8), saturating at 16; past that the
    // weight falls off as 16 / inverseRate.
    let tail = min(1.0, 16.0 / inverseRate);
    return mix(pow8, pow16, clamp((inverseRate - 8.0) * 0.125, 0.0, 1.0)) * tail;
}