mosure / bevy_gaussian_splatting
bevy gaussian splatting render pipeline plugin
Home Page: https://mosure.github.io/bevy_gaussian_splatting/index.html?arg1=cactus.gcloud
License: MIT License
Hi. Thank you for creating a Bevy implementation of Gaussian Splatting!
I tried running the example with this command: cargo run --release -- scenes/icecream.gcloud
It compiles and runs, but the result looks off (the splats are too large) and some of the gaussians flicker badly, even when the camera is not moving.
System information (Dell Precision 7560 notebook with Windows 11):
[Display]
DirectX version: 12.0
GPU processor: NVIDIA RTX A3000 Laptop GPU
Driver version: 536.45
Driver Type: DCH
Direct3D feature level: 12_1
CUDA Cores: 4096
Resizable BAR Yes
Dynamic Boost 2.0 Yes
WhisperMode 2.0 No
Advanced Optimus No
Maximum Graphics Power 105 W
Core clock: 1560 MHz
Memory data rate: 11.00 Gbps
Memory interface: 192-bit
Memory bandwidth: 264.05 GB/s
Total available graphics memory: 38504 MB
Dedicated video memory: 6144 MB GDDR6
System video memory: 0 MB
Shared system memory: 32360 MB
Video BIOS version: 94.04.42.40.06
IRQ: Not used
Bus: PCI Express x8 Gen4
Device ID: XXX
Part Number: XXX
[Components]
nvui.dll 8.17.15.3645 NVIDIA User Experience Driver Component
nvxdplcy.dll 8.17.15.3645 NVIDIA User Experience Driver Component
nvxdbat.dll 8.17.15.3645 NVIDIA User Experience Driver Component
nvxdapix.dll 8.17.15.3645 NVIDIA User Experience Driver Component
NVCPL.DLL 8.17.15.3645 NVIDIA User Experience Driver Component
nvCplUIR.dll 8.1.940.0 NVIDIA Control Panel
nvCplUI.exe 8.1.940.0 NVIDIA Control Panel
nvWSSR.dll 31.0.15.3645 NVIDIA Workstation Server
nvWSS.dll 31.0.15.3645 NVIDIA Workstation Server
nvViTvSR.dll 31.0.15.3645 NVIDIA Video Server
nvViTvS.dll 31.0.15.3645 NVIDIA Video Server
nvLicensingS.dll 6.14.15.3645 NVIDIA Licensing Server
nvDevToolSR.dll 31.0.15.3645 NVIDIA Licensing Server
nvDevToolS.dll 31.0.15.3645 NVIDIA 3D Settings Server
nvDispSR.dll 31.0.15.3645 NVIDIA Display Server
nvDispS.dll 31.0.15.3645 NVIDIA Display Server
PhysX 09.21.0713 NVIDIA PhysX
NVCUDA64.DLL 31.0.15.3645 NVIDIA CUDA 12.2.101 driver
nvGameSR.dll 31.0.15.3645 NVIDIA 3D Settings Server
nvGameS.dll 31.0.15.3645 NVIDIA 3D Settings Server
None => return RenderCommandResult::Failure,
};
pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]);
pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);
pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);
RenderCommandResult::Success
}
}
struct RadixSortNode {
gaussian_clouds: QueryState<(
&'static Handle<GaussianCloud>,
&'static GaussianCloudBindGroup
)>,
initialized: bool,
pipeline_idx: Option<u32>,
view_bind_group: QueryState<(
&'static GaussianViewBindGroup,
&'static ViewUniformOffset,
)>,
}
impl FromWorld for RadixSortNode {
fn from_world(world: &mut World) -> Self {
Self {
gaussian_clouds: world.query(),
initialized: false,
pipeline_idx: None,
view_bind_group: world.query(),
}
}
}
impl render_graph::Node for RadixSortNode {
fn update(&mut self, world: &mut World) {
let pipeline = world.resource::<GaussianCloudPipeline>();
let pipeline_cache = world.resource::<PipelineCache>();
if !self.initialized {
let mut pipelines_loaded = true;
for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
if let CachedPipelineState::Ok(_) =
pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
{
continue;
}
pipelines_loaded = false;
}
self.initialized = pipelines_loaded;
if !self.initialized {
return;
}
}
// advance to the next sort pipeline, starting at 0 on the first update
self.pipeline_idx = Some(match self.pipeline_idx {
None => 0,
Some(idx) => (idx + 1) % pipeline.radix_sort_pipelines.len() as u32,
});
self.gaussian_clouds.update_archetypes(world);
self.view_bind_group.update_archetypes(world);
}
fn run(
&self,
_graph: &mut render_graph::RenderGraphContext,
render_context: &mut RenderContext,
world: &World,
) -> Result<(), render_graph::NodeRunError> {
if !self.initialized || self.pipeline_idx.is_none() {
return Ok(());
}
let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort
let pipeline_cache = world.resource::<PipelineCache>();
let pipeline = world.resource::<GaussianCloudPipeline>();
let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();
let command_encoder = render_context.command_encoder();
for (
view_bind_group,
view_uniform_offset,
) in self.view_bind_group.iter_manual(world) {
for (
cloud_handle,
cloud_bind_group
) in self.gaussian_clouds.iter_manual(world) {
// skip clouds whose GPU asset has not been prepared yet instead of panicking
let Some(cloud) = world.resource::<RenderAssets<GaussianCloud>>().get(cloud_handle) else { continue; };
let radix_digit_places = ShaderDefines::default().radix_digit_places;
command_encoder.clear_buffer(
&cloud.sorting_global_buffer,
0,
None,
);
{
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());
// TODO: view/global
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
);
pass.set_bind_group(
2,
&cloud_bind_group.cloud_bind_group,
&[]
);
pass.set_bind_group(
3,
&cloud_bind_group.radix_sort_bind_groups[1],
&[],
);
let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
pass.set_pipeline(radix_sort_a);
let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
pass.set_pipeline(radix_sort_b);
pass.dispatch_workgroups(1, radix_digit_places, 1);
}
for pass_idx in 0..radix_digit_places {
if pass_idx > 0 {
let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::<u32>() as u32;
command_encoder.clear_buffer(
&cloud.sorting_global_buffer,
0,
std::num::NonZeroU64::new(size as u64).unwrap().into()
);
}
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());
let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
pass.set_pipeline(radix_sort_c);
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
);
pass.set_bind_group(
2,
&cloud_bind_group.cloud_bind_group,
&[]
);
pass.set_bind_group(
3,
&cloud_bind_group.radix_sort_bind_groups[pass_idx as usize],
&[],
);
let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
}
}
}
Ok(())
}
}
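As a CPU reference for what the node above dispatches, here is a minimal sketch of a least-significant-digit radix sort over (key, index) pairs. It is not the plugin's shader code: the 8-bit digit width, the resulting 4 digit places, and the names `Entry`, `radix_sort` and `workgroup_count` are assumptions made for illustration, and mapping the three steps onto the three compute pipelines is only a guess. The ceiling division is the same arithmetic the node uses to size its dispatches.

```rust
/// One sortable record: a 32-bit key and the index of the gaussian it refers to.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Entry {
    key: u32,
    index: u32,
}

// Assumed values for this sketch only; the plugin reads its own from `ShaderDefines`.
const RADIX_BITS: u32 = 8;
const RADIX_BASE: usize = 1 << RADIX_BITS; // 256 buckets per digit
const RADIX_DIGIT_PLACES: u32 = 32 / RADIX_BITS; // 4 passes over a u32 key

/// Ceiling division: how many workgroups are needed to cover `entries` items.
fn workgroup_count(entries: u32, entries_per_workgroup: u32) -> u32 {
    (entries + entries_per_workgroup - 1) / entries_per_workgroup
}

/// Stable least-significant-digit radix sort, ascending by key.
fn radix_sort(entries: &mut Vec<Entry>) {
    let mut scratch = entries.clone();
    for place in 0..RADIX_DIGIT_PLACES {
        let shift = place * RADIX_BITS;
        let digit = |e: &Entry| ((e.key >> shift) as usize) & (RADIX_BASE - 1);

        // Step 1: histogram of digit values.
        let mut counts = [0usize; RADIX_BASE];
        for e in entries.iter() {
            counts[digit(e)] += 1;
        }

        // Step 2: exclusive prefix sum turns counts into bucket start offsets.
        let mut offset = 0;
        for c in counts.iter_mut() {
            let n = *c;
            *c = offset;
            offset += n;
        }

        // Step 3: scatter each entry into its bucket, preserving input order.
        for e in entries.iter() {
            let d = digit(e);
            scratch[counts[d]] = *e;
            counts[d] += 1;
        }
        std::mem::swap(entries, &mut scratch);
    }
}

fn main() {
    let mut entries: Vec<Entry> = [170u32, 45, 75, 90, 802, 24, 2, 66]
        .iter()
        .enumerate()
        .map(|(i, &key)| Entry { key, index: i as u32 })
        .collect();
    radix_sort(&mut entries);
    for e in &entries {
        println!("key {:>3} -> original index {}", e.key, e.index);
    }
    println!("workgroups for 1000 entries: {}", workgroup_count(1000, 256));
}
```

Because each pass is stable, sorting by the lowest digit first and the highest digit last leaves the entries ordered by the full key.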
bevy_gaussian_splatting/src/sort/std.rs
Line 97 in 6033195
use bevy::{
prelude::*,
asset::LoadState,
utils::Instant,
};
use crate::{
GaussianCloud,
GaussianCloudSettings,
sort::{
SortedEntries,
SortMode,
},
};
#[derive(Default)]
pub struct StdSortPlugin;
impl Plugin for StdSortPlugin {
fn build(&self, app: &mut App) {
app.add_systems(Update, std_sort);
}
}
pub fn std_sort(
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
gaussian_clouds: Query<(
&Handle<GaussianCloud>,
&Handle<SortedEntries>,
&GaussianCloudSettings,
)>,
cameras: Query<(
&GlobalTransform,
&Camera3d,
)>,
mut last_camera_position: Local<Vec3>,
mut last_sort_time: Local<Option<Instant>>,
) {
let period = std::time::Duration::from_millis(100);
if let Some(last_sort_time) = last_sort_time.as_ref() {
if last_sort_time.elapsed() < period {
return;
}
}
for (
camera_transform,
_camera,
) in cameras.iter() {
let camera_position = camera_transform.compute_transform().translation;
if *last_camera_position == camera_position {
return;
}
for (
gaussian_cloud_handle,
sorted_entries_handle,
settings,
) in gaussian_clouds.iter() {
if settings.sort_mode != SortMode::Std {
continue;
}
if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
continue;
}
if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) {
continue;
}
if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) {
if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) {
assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len());
*last_camera_position = camera_position;
*last_sort_time = Some(Instant::now());
gaussian_cloud.gaussians.iter()
.zip(sorted_entries.sorted.iter_mut())
.enumerate()
.for_each(|(idx, (gaussian, sort_entry))| {
let position = Vec3::from_slice(gaussian.position.as_ref());
let delta = camera_position - position;
sort_entry.key = bytemuck::cast(delta.length_squared());
sort_entry.index = idx as u32;
});
// back-to-front: larger squared distance first; total_cmp cannot panic on NaN keys
sorted_entries.sorted.sort_unstable_by(|a, b| {
bytemuck::cast::<u32, f32>(b.key).total_cmp(&bytemuck::cast::<u32, f32>(a.key))
});
// TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)
}
}
}
}
}
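The key construction in std_sort can be tried in isolation. The following is a minimal sketch using only the standard library: it stores the squared camera distance as raw f32 bits in a u32 key and orders entries farthest-first. The names `squared_distance` and `sort_back_to_front` are hypothetical, and `f32::to_bits` / `f32::from_bits` stand in for the `bytemuck::cast` calls used in the source.

```rust
/// (key, index) pair; `key` holds the raw bits of a non-negative f32.
#[derive(Clone, Copy, Debug, PartialEq)]
struct SortEntry {
    key: u32,
    index: u32,
}

/// Squared Euclidean distance between two points.
fn squared_distance(a: [f32; 3], b: [f32; 3]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Returns one entry per position, ordered farthest-first relative to `camera`.
fn sort_back_to_front(camera: [f32; 3], positions: &[[f32; 3]]) -> Vec<SortEntry> {
    let mut entries: Vec<SortEntry> = positions
        .iter()
        .enumerate()
        .map(|(idx, &p)| SortEntry {
            key: squared_distance(camera, p).to_bits(),
            index: idx as u32,
        })
        .collect();
    // Descending by distance; total_cmp defines an order even if a key is NaN.
    entries.sort_unstable_by(|a, b| f32::from_bits(b.key).total_cmp(&f32::from_bits(a.key)));
    entries
}

fn main() {
    let positions = [[0.0, 0.0, 1.0], [0.0, 0.0, 5.0], [0.0, 0.0, 3.0]];
    let sorted = sort_back_to_front([0.0, 0.0, 0.0], &positions);
    for e in &sorted {
        println!("index {} at squared distance {}", e.index, f32::from_bits(e.key));
    }
}
```

For non-negative finite floats the IEEE 754 bit pattern increases with the value, so the same u32 keys can also be compared as plain integers, which is what makes them usable by an integer radix sort.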
bevy_gaussian_splatting/src/render/mod.rs
Line 186 in 4356f87
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let gaussian_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("gaussian cloud buffer"),
contents: bytemuck::cast_slice(gaussian_cloud.gaussians.as_slice()),
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});
let count = gaussian_cloud.gaussians.len();
// TODO: keep draw_indirect at the gaussian cloud level
let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("draw indirect buffer"),
size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
bevy_gaussian_splatting/src/io/ply.rs
Line 65 in 507a28d
fn set_property(&mut self, key: String, property: Property) {
match (key.as_ref(), property) {
("x", Property::Float(v)) => self.position_visibility.position[0] = v,
("y", Property::Float(v)) => self.position_visibility.position[1] = v,
("z", Property::Float(v)) => self.position_visibility.position[2] = v,
("f_dc_0", Property::Float(v)) => self.spherical_harmonic.set(0, v),
("f_dc_1", Property::Float(v)) => self.spherical_harmonic.set(1, v),
("f_dc_2", Property::Float(v)) => self.spherical_harmonic.set(2, v),
("scale_0", Property::Float(v)) => self.scale_opacity.scale[0] = v,
("scale_1", Property::Float(v)) => self.scale_opacity.scale[1] = v,
("scale_2", Property::Float(v)) => self.scale_opacity.scale[2] = v,
("opacity", Property::Float(v)) => self.scale_opacity.opacity = 1.0 / (1.0 + (-v).exp()),
("rot_0", Property::Float(v)) => self.rotation.rotation[0] = v,
("rot_1", Property::Float(v)) => self.rotation.rotation[1] = v,
("rot_2", Property::Float(v)) => self.rotation.rotation[2] = v,
("rot_3", Property::Float(v)) => self.rotation.rotation[3] = v,
(_, Property::Float(v)) if key.starts_with("f_rest_") => {
let i = key[7..].parse::<usize>().unwrap();
// interleaved
// if (i + 3) < SH_COEFF_COUNT {
// self.spherical_harmonic.coefficients[i + 3] = v;
// }
// planar
let channel = i / SH_COEFF_COUNT_PER_CHANNEL;
let coefficient = if SH_COEFF_COUNT_PER_CHANNEL == 1 {
1
} else {
(i % (SH_COEFF_COUNT_PER_CHANNEL - 1)) + 1
};
let interleaved_idx = coefficient * SH_CHANNELS + channel;
if interleaved_idx < SH_COEFF_COUNT {
self.spherical_harmonic.set(interleaved_idx, v);
} else {
// TODO: convert higher degree SH to lower degree SH
}
}
(_, _) => {},
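Two of the conversions in set_property can be illustrated on their own: the logistic function applied to the stored opacity, and the mapping from a planar f_rest_<i> index to an interleaved coefficient slot. This is a sketch under stated assumptions, not the loader's exact indexing: it assumes the file stores `rest_per_channel` values per colour channel, one channel after the other, and that slot 0 of each channel is the DC term from f_dc_*. The helper names `sigmoid` and `planar_to_interleaved` are hypothetical.

```rust
const SH_CHANNELS: usize = 3;

/// Logistic function mapping a stored opacity logit to the range (0, 1).
fn sigmoid(v: f32) -> f32 {
    1.0 / (1.0 + (-v).exp())
}

/// Maps a planar `f_rest_<i>` index to an interleaved slot.
/// Assumption: `rest_per_channel` values per channel, stored channel after channel,
/// with coefficient 0 of every channel reserved for the DC term.
fn planar_to_interleaved(i: usize, rest_per_channel: usize) -> usize {
    let channel = i / rest_per_channel;
    let coefficient = i % rest_per_channel + 1; // skip the DC coefficient
    coefficient * SH_CHANNELS + channel
}

fn main() {
    // Degree-3 spherical harmonics: 16 coefficients per channel, 15 of them "rest".
    let rest_per_channel = 15;
    for i in [0, 14, 15, 30, 44] {
        println!("f_rest_{} -> slot {}", i, planar_to_interleaved(i, rest_per_channel));
    }
    println!("opacity for logit 0.0: {}", sigmoid(0.0));
}
```

With 15 rest values per channel the slots 3..=47 are each hit exactly once, leaving slots 0..=2 for the three DC terms.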
use bincode2::serialize_into;
use flate2::{
Compression,
write::GzEncoder,
};
use bevy_gaussian_splatting::{
GaussianCloud,
ply::parse_ply,
};
fn main() {
let filename = std::env::args().nth(1).expect("no filename given");
println!("converting {}", filename);
// filepath to BufRead
let file = std::fs::File::open(&filename).expect("failed to open file");
let mut reader = std::io::BufReader::new(file);
let cloud = GaussianCloud(parse_ply(&mut reader).expect("failed to parse ply file"));
// write cloud to .gcloud file (replace the .ply extension; works for paths such as ./scenes/test.ply)
let gcloud_filename = std::path::Path::new(&filename).with_extension("gcloud");
// let gcloud_file = std::fs::File::create(&gcloud_filename).expect("failed to create file");
// let mut writer = std::io::BufWriter::new(gcloud_file);
// serialize_into(&mut writer, &cloud).expect("failed to encode cloud");
// write gzip-compressed .gcloud
let gz_file = std::fs::File::create(&gcloud_filename).expect("failed to create file");
let mut gz_writer = std::io::BufWriter::new(gz_file);
let mut gz_encoder = GzEncoder::new(&mut gz_writer, Compression::default()); // TODO: consider switching to fast (or support multiple options), default is a bit slow
serialize_into(&mut gz_encoder, &cloud).expect("failed to encode cloud");
// finish the gzip stream explicitly so that write errors are not silently dropped
gz_encoder.finish().expect("failed to finish gzip stream");
}
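Deriving the output file name is easy to get subtly wrong: cutting the input at the first '.' returns an empty string for a path like ./scenes/test.ply. A minimal sketch of a more robust derivation with the standard library follows; the helper names `gcloud_path` and `naive_base` are hypothetical and exist only for comparison.

```rust
use std::path::{Path, PathBuf};

/// Replaces the final extension of `input` with `.gcloud`.
fn gcloud_path(input: &str) -> PathBuf {
    Path::new(input).with_extension("gcloud")
}

/// Everything before the first '.', shown only to illustrate the pitfall.
fn naive_base(input: &str) -> &str {
    input.split('.').next().unwrap_or("")
}

fn main() {
    for input in ["scenes/test.ply", "./scenes/test.ply", "scenes/v1.2/test.ply"] {
        println!(
            "{:<24} naive base: {:<12?} with_extension: {}",
            input,
            naive_base(input),
            gcloud_path(input).display()
        );
    }
}
```

`Path::with_extension` only touches the extension of the last path component, so dots in directory names are left alone.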
convert storage buffers to https://bevyengine.org/news/bevy-0-12/#gpuarraybuffer
see: https://gist.github.com/cart/3a9f190bd5e789a7d42317c28843ffca
Would it be possible to provide scenes/test.ply here?
# - name: radix sort test
# run: cargo run --bin test_radix --features="debug_gpu"
# TODO: test wasm build, deploy, and run
use rand::{
prelude::Distribution,
Rng,
};
use std::marker::Copy;
use bevy::{
prelude::*,
reflect::TypeUuid,
render::render_resource::ShaderType,
};
use bytemuck::{
Pod,
Zeroable,
};
use serde::{
Deserialize,
Serialize,
};
// TODO: add more particle system functionality (e.g. lifetime, color)
#[derive(
Clone,
Debug,
Copy,
PartialEq,
Reflect,
ShaderType,
Pod,
Zeroable,
Serialize,
Deserialize,
)]
#[repr(C)]
pub struct ParticleBehavior {
pub indicies: [u32; 4],
pub velocity: [f32; 4],
pub acceleration: [f32; 4],
pub jerk: [f32; 4],
}
impl Default for ParticleBehavior {
fn default() -> Self {
Self {
indicies: [0, 0, 0, 0],
velocity: [0.0, 0.0, 0.0, 0.0],
acceleration: [0.0, 0.0, 0.0, 0.0],
jerk: [0.0, 0.0, 0.0, 0.0],
}
}
}
#[derive(
Asset,
Clone,
Debug,
Default,
PartialEq,
Reflect,
TypeUuid,
Serialize,
Deserialize,
)]
#[uuid = "ac2f08eb-6463-2131-6772-51571ea332d5"]
pub struct ParticleBehaviors(pub Vec<ParticleBehavior>);
impl Distribution<ParticleBehavior> for rand::distributions::Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ParticleBehavior {
ParticleBehavior {
acceleration: [
rng.gen_range(-0.01..0.01),
rng.gen_range(-0.01..0.01),
rng.gen_range(-0.01..0.01),
rng.gen_range(-0.01..0.01),
],
jerk: [
rng.gen_range(-0.0001..0.0001),
rng.gen_range(-0.0001..0.0001),
rng.gen_range(-0.0001..0.0001),
rng.gen_range(-0.0001..0.0001),
],
velocity: [
rng.gen_range(-1.0..1.0),
rng.gen_range(-1.0..1.0),
rng.gen_range(-1.0..1.0),
rng.gen_range(-1.0..1.0),
],
..Default::default()
}
}
}
pub fn random_particle_behaviors(n: usize) -> ParticleBehaviors {
let mut rng = rand::thread_rng();
let mut behaviors = Vec::with_capacity(n);
for i in 0..n {
let mut behavior: ParticleBehavior = rng.gen();
behavior.indicies[0] = i as u32;
behaviors.push(behavior);
}
ParticleBehaviors(behaviors)
}
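How velocity, acceleration and jerk combine is not shown in the snippet above. One plausible reading, offered purely as an assumption, is that they are successive time derivatives of position integrated in closed form: displacement(t) = v·t + a·t²/2 + j·t³/6. The sketch below implements that formula per component; the function `displacement` is hypothetical and not part of the crate.

```rust
/// Hypothetical helper: displacement after `t` time units under constant jerk.
/// Implements v*t + a*t^2/2 + j*t^3/6 for each of the four components.
fn displacement(
    velocity: [f32; 4],
    acceleration: [f32; 4],
    jerk: [f32; 4],
    t: f32,
) -> [f32; 4] {
    let mut out = [0.0f32; 4];
    for k in 0..4 {
        out[k] = velocity[k] * t
            + 0.5 * acceleration[k] * t * t
            + jerk[k] * t * t * t / 6.0;
    }
    out
}

fn main() {
    let d = displacement(
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 2.0, 0.0, 0.0],
        [0.0, 0.0, 6.0, 0.0],
        2.0,
    );
    println!("{:?}", d);
}
```

Evaluating the closed form at an absolute time, rather than accumulating per-frame steps, keeps the result independent of frame rate, which may be why a design like this stores the derivatives per particle.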
Deserialize,
)]
#[repr(C)]
// TODO: support f16 gaussian clouds (shader and asset loader)
pub struct Gaussian {
pub rotation: [f32; 4],
pub position: Vec3,
supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer
bevy_gaussian_splatting/src/sort/mod.rs
Line 169 in 5c9a20a
use bevy::{
prelude::*,
asset::LoadState,
ecs::system::{
lifetimeless::SRes,
SystemParamItem,
},
reflect::TypeUuid,
render::{
render_asset::{
RenderAsset,
RenderAssetPlugin,
PrepareAssetError,
},
render_resource::*,
renderer::RenderDevice,
},
};
use bytemuck::{
Pod,
Zeroable,
};
use static_assertions::assert_cfg;
use crate::{
GaussianCloud,
GaussianCloudSettings,
};
#[cfg(feature = "sort_radix")]
pub mod radix;
#[cfg(feature = "sort_rayon")]
pub mod rayon;
assert_cfg!(
any(
feature = "sort_radix",
feature = "sort_rayon",
),
"no sort mode enabled",
);
#[derive(
Component,
Debug,
Clone,
PartialEq,
Reflect,
)]
pub enum SortMode {
None,
#[cfg(feature = "sort_radix")]
Radix,
#[cfg(feature = "sort_rayon")]
Rayon,
}
impl Default for SortMode {
#[allow(unreachable_code)]
fn default() -> Self {
#[cfg(feature = "sort_rayon")]
return Self::Rayon;
#[cfg(feature = "sort_radix")]
return Self::Radix;
Self::None
}
}
#[derive(Default)]
pub struct SortPlugin;
impl Plugin for SortPlugin {
fn build(&self, app: &mut App) {
#[cfg(feature = "sort_radix")]
app.add_plugins(radix::RadixSortPlugin);
#[cfg(feature = "sort_rayon")]
app.add_plugins(rayon::RayonSortPlugin);
app.register_type::<SortedEntries>();
app.init_asset::<SortedEntries>();
app.register_asset_reflect::<SortedEntries>();
app.add_plugins(RenderAssetPlugin::<SortedEntries>::default());
app.add_systems(Update, auto_insert_sorted_entries);
}
}
fn auto_insert_sorted_entries(
mut commands: Commands,
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
gaussian_clouds: Query<(
Entity,
&Handle<GaussianCloud>,
&GaussianCloudSettings,
Without<Handle<SortedEntries>>,
)>,
) {
for (
entity,
gaussian_cloud_handle,
_settings,
_,
) in gaussian_clouds.iter() {
// // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection)
// if settings.sort_mode == SortMode::None {
// continue;
// }
if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
continue;
}
let Some(cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) else {
continue;
};
// TODO: move gaussian_cloud and sorted_entry assets into an asset bundle
let sorted_entries = sorted_entries_res.add(SortedEntries {
sorted: (0..cloud.gaussians.len())
.map(|idx| {
SortEntry {
key: 1,
index: idx as u32,
}
})
.collect(),
});
commands.entity(entity)
.insert(sorted_entries);
}
}
#[derive(
Clone,
Copy,
Debug,
Default,
PartialEq,
Reflect,
ShaderType,
Pod,
Zeroable,
)]
#[repr(C)]
pub struct SortEntry {
pub key: u32,
pub index: u32,
}
// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None
// supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer
#[derive(
Clone,
Asset,
Debug,
Default,
PartialEq,
Reflect,
TypeUuid,
)]
#[uuid = "ac2f08eb-fa13-ccdd-ea11-51571ea332d5"]
pub struct SortedEntries {
pub sorted: Vec<SortEntry>,
}
impl RenderAsset for SortedEntries {
type ExtractedAsset = SortedEntries;
type PreparedAsset = GpuSortedEntry;
type Param = SRes<RenderDevice>;
fn extract_asset(&self) -> Self::ExtractedAsset {
self.clone()
}
fn prepare_asset(
sorted_entries: Self::ExtractedAsset,
render_device: &mut SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let sorted_entry_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("sorted_entry_buffer"),
contents: bytemuck::cast_slice(sorted_entries.sorted.as_slice()),
usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});
let count = sorted_entries.sorted.len();
Ok(GpuSortedEntry {
sorted_entry_buffer,
count,
})
}
}
// TODO: support instancing and multiple cameras
// separate entry_buffer_a binding into a unique bind group to optimize buffer updates
#[derive(Debug, Clone)]
pub struct GpuSortedEntry {
pub sorted_entry_buffer: Buffer,
pub count: usize,
}
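To make the buffer layout concrete, here is a minimal standalone sketch of the identity initialisation done in auto_insert_sorted_entries and of the resulting buffer size. It mirrors `SortEntry` with a plain `#[repr(C)]` struct and no Bevy or bytemuck dependency; the helpers `identity_entries` and `buffer_size_bytes` are hypothetical.

```rust
/// Standalone mirror of the (key, index) layout: two u32 fields, 8 bytes, no padding.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct SortEntry {
    key: u32,
    index: u32,
}

/// Unsorted starting state: entry `i` points at gaussian `i`, all keys equal.
fn identity_entries(count: usize) -> Vec<SortEntry> {
    (0..count)
        .map(|idx| SortEntry { key: 1, index: idx as u32 })
        .collect()
}

/// Number of bytes a GPU buffer needs to hold `entries`.
fn buffer_size_bytes(entries: &[SortEntry]) -> usize {
    std::mem::size_of_val(entries)
}

fn main() {
    let entries = identity_entries(1_000_000);
    println!(
        "{} entries, {} bytes each, {} bytes total",
        entries.len(),
        std::mem::size_of::<SortEntry>(),
        buffer_size_bytes(&entries)
    );
}
```

Since the index is a u32, a single cloud is limited to about 4.29 billion gaussians, far beyond what the 8 bytes per entry would allow on typical GPU memory anyway.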
separate entry_buffer_a binding into a unique bind group to optimize buffer updates
bevy_gaussian_splatting/src/sort/mod.rs
Line 214 in 5c9a20a
use bevy::{
prelude::*,
asset::LoadState,
ecs::system::{
lifetimeless::SRes,
SystemParamItem,
},
reflect::TypeUuid,
render::{
render_asset::{
RenderAsset,
RenderAssetPlugin,
PrepareAssetError,
},
render_resource::*,
renderer::RenderDevice,
},
};
use bytemuck::{
Pod,
Zeroable,
};
use static_assertions::assert_cfg;
use crate::{
GaussianCloud,
GaussianCloudSettings,
};
#[cfg(feature = "sort_radix")]
pub mod radix;
#[cfg(feature = "sort_rayon")]
pub mod rayon;
assert_cfg!(
any(
feature = "sort_radix",
feature = "sort_rayon",
),
"no sort mode enabled",
);
#[derive(
Component,
Debug,
Clone,
PartialEq,
Reflect,
)]
pub enum SortMode {
None,
#[cfg(feature = "sort_radix")]
Radix,
#[cfg(feature = "sort_rayon")]
Rayon,
}
impl Default for SortMode {
#[allow(unreachable_code)]
fn default() -> Self {
#[cfg(feature = "sort_rayon")]
return Self::Rayon;
#[cfg(feature = "sort_radix")]
return Self::Radix;
Self::None
}
}
#[derive(Default)]
pub struct SortPlugin;
impl Plugin for SortPlugin {
fn build(&self, app: &mut App) {
#[cfg(feature = "sort_radix")]
app.add_plugins(radix::RadixSortPlugin);
#[cfg(feature = "sort_rayon")]
app.add_plugins(rayon::RayonSortPlugin);
app.register_type::<SortedEntries>();
app.init_asset::<SortedEntries>();
app.register_asset_reflect::<SortedEntries>();
app.add_plugins(RenderAssetPlugin::<SortedEntries>::default());
app.add_systems(Update, auto_insert_sorted_entries);
}
}
fn auto_insert_sorted_entries(
mut commands: Commands,
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
gaussian_clouds: Query<(
Entity,
&Handle<GaussianCloud>,
&GaussianCloudSettings,
Without<Handle<SortedEntries>>,
)>,
) {
for (
entity,
gaussian_cloud_handle,
_settings,
_,
) in gaussian_clouds.iter() {
// // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection)
// if settings.sort_mode == SortMode::None {
// continue;
// }
if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
continue;
}
let cloud = gaussian_clouds_res.get(gaussian_cloud_handle);
if cloud.is_none() {
continue;
}
let cloud = cloud.unwrap();
// TODO: move gaussian_cloud and sorted_entry assets into an asset bundle
let sorted_entries = sorted_entries_res.add(SortedEntries {
sorted: (0..cloud.gaussians.len())
.map(|idx| {
SortEntry {
key: 1,
index: idx as u32,
}
})
.collect(),
});
commands.entity(entity)
.insert(sorted_entries);
}
}
#[derive(
Clone,
Copy,
Debug,
Default,
PartialEq,
Reflect,
ShaderType,
Pod,
Zeroable,
)]
#[repr(C)]
pub struct SortEntry {
pub key: u32,
pub index: u32,
}
// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None
// supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer
#[derive(
Clone,
Asset,
Debug,
Default,
PartialEq,
Reflect,
TypeUuid,
)]
#[uuid = "ac2f08eb-fa13-ccdd-ea11-51571ea332d5"]
pub struct SortedEntries {
pub sorted: Vec<SortEntry>,
}
impl RenderAsset for SortedEntries {
type ExtractedAsset = SortedEntries;
type PreparedAsset = GpuSortedEntry;
type Param = SRes<RenderDevice>;
fn extract_asset(&self) -> Self::ExtractedAsset {
self.clone()
}
fn prepare_asset(
sorted_entries: Self::ExtractedAsset,
render_device: &mut SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let sorted_entry_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("sorted_entry_buffer"),
contents: bytemuck::cast_slice(sorted_entries.sorted.as_slice()),
usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});
let count = sorted_entries.sorted.len();
Ok(GpuSortedEntry {
sorted_entry_buffer,
count,
})
}
}
// TODO: support instancing and multiple cameras
// separate the entry_buffer_a binding into a unique bind group to optimize buffer updates
#[derive(Debug, Clone)]
pub struct GpuSortedEntry {
pub sorted_entry_buffer: Buffer,
pub count: usize,
}
bevy_gaussian_splatting/src/sort/std.rs
Line 26 in 507a28d
}
}
// TODO: async CPU sort to prevent frame drops on large clouds
pub fn std_sort(
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
requires GaussianCloud RenderAsset dependency
use bevy::{
prelude::*,
asset::{
load_internal_asset,
LoadState,
},
core_pipeline::core_3d::CORE_3D,
ecs::system::{
lifetimeless::SRes,
SystemParamItem,
},
render::{
render_asset::RenderAssets,
render_resource::*,
renderer::{
RenderContext,
RenderDevice,
},
render_graph::{
Node,
NodeRunError,
RenderGraphApp,
RenderGraphContext,
},
Render,
RenderApp,
RenderSet,
view::ViewUniformOffset,
},
};
use crate::{
gaussian::GaussianCloud,
render::{
GaussianCloudBindGroup,
GaussianCloudPipeline,
GaussianUniformBindGroups,
GaussianViewBindGroup,
ShaderDefines,
shader_defs,
},
};
const RADIX_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(6234673214);
const TEMPORAL_SORT_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(1634543224);
pub mod node {
pub const RADIX_SORT: &str = "radix_sort";
}
#[derive(Default)]
pub struct RadixSortPlugin;
impl Plugin for RadixSortPlugin {
fn build(&self, app: &mut App) {
load_internal_asset!(
app,
RADIX_SHADER_HANDLE,
"radix.wgsl",
Shader::from_wgsl
);
load_internal_asset!(
app,
TEMPORAL_SORT_SHADER_HANDLE,
"temporal.wgsl",
Shader::from_wgsl
);
if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
render_app
.add_render_graph_node::<RadixSortNode>(
CORE_3D,
node::RADIX_SORT,
)
.add_render_graph_edge(
CORE_3D,
node::RADIX_SORT,
bevy::core_pipeline::core_3d::graph::node::PREPASS,
);
render_app
.add_systems(
Render,
(
queue_radix_bind_group.in_set(RenderSet::QueueMeshes),
),
);
}
}
fn finish(&self, app: &mut App) {
if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
render_app
.init_resource::<RadixSortPipeline>();
}
}
}
// TODO: allow swapping of sort backends
// requires GaussianCloud RenderAsset dependency
#[derive(Debug, Clone)]
pub struct GpuRadixBuffers {
pub sorting_global_buffer: Buffer,
pub sorting_status_counter_buffer: Buffer,
pub sorting_pass_buffers: [Buffer; 4],
pub entry_buffer_a: Buffer,
pub entry_buffer_b: Buffer,
}
impl GpuRadixBuffers {
pub fn new(
count: usize,
render_device: &mut SystemParamItem<SRes<RenderDevice>>,
) -> Self {
let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("sorting global buffer"),
size: ShaderDefines::default().sorting_buffer_size as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let sorting_status_counter_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("status counters buffer"),
size: ShaderDefines::default().sorting_status_counters_buffer_size(count) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let sorting_pass_buffers = (0..4)
.map(|idx| {
render_device.create_buffer_with_data(&BufferInitDescriptor {
label: format!("sorting pass buffer {}", idx).as_str().into(),
// pass index encoded as a little-endian u32 (4 bytes)
contents: &[idx as u8, 0, 0, 0],
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
})
})
.collect::<Vec<Buffer>>()
.try_into()
.unwrap();
let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer a"),
size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer b"),
size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
GpuRadixBuffers {
sorting_global_buffer,
sorting_status_counter_buffer,
sorting_pass_buffers,
entry_buffer_a,
entry_buffer_b,
}
}
}
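The buffer sizes above are plain byte arithmetic: one key/index pair per gaussian, and one 4-byte pass index per pass buffer. The sketch below restates that arithmetic with std only. `Entry`, `entry_buffer_size`, and `pass_index_bytes` are hypothetical names used for illustration; sizing from a `#[repr(C)]` struct instead of the `(u32, u32)` tuple is a suggestion, not what the snippet does.

```rust
use std::mem::size_of;

/// Hypothetical stand-in for a key/index pair with a fixed C layout.
/// With `#[repr(C)]` and two `u32` fields its size is exactly 8 bytes.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Entry {
    pub key: u32,
    pub index: u32,
}

/// Byte size of a buffer holding `count` entries.
pub fn entry_buffer_size(count: usize) -> u64 {
    (count * size_of::<Entry>()) as u64
}

/// Little-endian bytes of a pass index, equivalent to
/// `[idx as u8, 0, 0, 0]` for indices below 256.
pub fn pass_index_bytes(pass_idx: u32) -> [u8; 4] {
    pass_idx.to_le_bytes()
}
```

Deriving the size from a `#[repr(C)]` type keeps the CPU-side size tied to a declared layout, whereas the layout of a Rust tuple is not specified by the language.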
#[derive(Resource)]
pub struct RadixSortPipeline {
pub radix_sort_layout: BindGroupLayout,
pub radix_sort_pipelines: [CachedComputePipelineId; 3],
}
impl FromWorld for RadixSortPipeline {
fn from_world(render_world: &mut World) -> Self {
let render_device = render_world.resource::<RenderDevice>();
let gaussian_cloud_pipeline = render_world.resource::<GaussianCloudPipeline>();
let sorting_buffer_entry = BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(ShaderDefines::default().sorting_buffer_size as u64),
},
count: None,
};
let sorting_status_counters_buffer_entry = BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(ShaderDefines::default().sorting_status_counters_buffer_size(1) as u64),
},
count: None,
};
let draw_indirect_buffer_entry = BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<wgpu::util::DrawIndirect>() as u64),
},
count: None,
};
let radix_sort_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("radix_sort_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<u32>() as u64),
},
count: None,
},
sorting_buffer_entry,
sorting_status_counters_buffer_entry,
draw_indirect_buffer_entry,
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 5,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
},
count: None,
},
],
});
let sorting_layout = vec![
gaussian_cloud_pipeline.view_layout.clone(),
gaussian_cloud_pipeline.gaussian_uniform_layout.clone(),
gaussian_cloud_pipeline.gaussian_cloud_layout.clone(),
radix_sort_layout.clone(),
];
let shader_defs = shader_defs(false, false);
let pipeline_cache = render_world.resource::<PipelineCache>();
let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_a".into()),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
entry_point: "radix_sort_a".into(),
});
let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_b".into()),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
entry_point: "radix_sort_b".into(),
});
let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_c".into()),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
entry_point: "radix_sort_c".into(),
});
RadixSortPipeline {
radix_sort_layout,
radix_sort_pipelines: [
radix_sort_a,
radix_sort_b,
radix_sort_c,
],
}
}
}
#[derive(Component)]
pub struct RadixBindGroup {
pub radix_sort_bind_groups: [BindGroup; 4],
}
pub fn queue_radix_bind_group(
mut commands: Commands,
radix_pipeline: Res<RadixSortPipeline>,
render_device: Res<RenderDevice>,
asset_server: Res<AssetServer>,
gaussian_cloud_res: Res<RenderAssets<GaussianCloud>>,
gaussian_clouds: Query<(
Entity,
&Handle<GaussianCloud>,
)>,
) {
for (entity, cloud_handle) in gaussian_clouds.iter() {
if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) {
continue;
}
// skip clouds that have not been prepared for the GPU yet
let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
continue;
};
let sorting_global_entry = BindGroupEntry {
binding: 1,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.radix_sort_buffers.sorting_global_buffer,
offset: 0,
size: BufferSize::new(cloud.radix_sort_buffers.sorting_global_buffer.size()),
}),
};
let sorting_status_counters_entry = BindGroupEntry {
binding: 2,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.radix_sort_buffers.sorting_status_counter_buffer,
offset: 0,
size: BufferSize::new(cloud.radix_sort_buffers.sorting_status_counter_buffer.size()),
}),
};
let draw_indirect_entry = BindGroupEntry {
binding: 3,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.draw_indirect_buffer,
offset: 0,
size: BufferSize::new(cloud.draw_indirect_buffer.size()),
}),
};
let radix_sort_bind_groups: [BindGroup; 4] = (0..4)
.map(|idx| {
render_device.create_bind_group(
format!("radix_sort_bind_group {}", idx).as_str(),
&radix_pipeline.radix_sort_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.radix_sort_buffers.sorting_pass_buffers[idx],
offset: 0,
size: BufferSize::new(std::mem::size_of::<u32>() as u64),
}),
},
sorting_global_entry.clone(),
sorting_status_counters_entry.clone(),
draw_indirect_entry.clone(),
BindGroupEntry {
binding: 4,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&cloud.radix_sort_buffers.entry_buffer_a
} else {
&cloud.radix_sort_buffers.entry_buffer_b
},
offset: 0,
size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64),
}),
},
BindGroupEntry {
binding: 5,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&cloud.radix_sort_buffers.entry_buffer_b
} else {
&cloud.radix_sort_buffers.entry_buffer_a
},
offset: 0,
size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64),
}),
},
],
)
})
.collect::<Vec<BindGroup>>()
.try_into()
.unwrap();
commands.entity(entity).insert(RadixBindGroup {
radix_sort_bind_groups,
});
}
}
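The four bind groups above differ only in the pass-index uniform and in which entry buffer sits at binding 4 versus binding 5. A minimal sketch of that alternation follows; `EntryBuffer` and `entry_bindings` are hypothetical names, and which binding the shader reads from or writes to is not shown in the snippet, so the sketch makes no claim about it.

```rust
/// The two entry buffers that successive passes alternate between.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EntryBuffer {
    A,
    B,
}

/// Returns `(buffer at binding 4, buffer at binding 5)` for a pass index,
/// following the `idx % 2 == 0` selection in the bind group setup above.
pub fn entry_bindings(pass_idx: usize) -> (EntryBuffer, EntryBuffer) {
    if pass_idx % 2 == 0 {
        (EntryBuffer::A, EntryBuffer::B)
    } else {
        (EntryBuffer::B, EntryBuffer::A)
    }
}
```

Because the roles swap on every pass, the buffer at binding 5 in one pass is the buffer at binding 4 in the next, which is the usual ping-pong arrangement for multi-pass sorts.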
pub struct RadixSortNode {
gaussian_clouds: QueryState<(
&'static Handle<GaussianCloud>,
&'static GaussianCloudBindGroup,
&'static RadixBindGroup,
)>,
initialized: bool,
view_bind_group: QueryState<(
&'static GaussianViewBindGroup,
&'static ViewUniformOffset,
)>,
}
impl FromWorld for RadixSortNode {
fn from_world(world: &mut World) -> Self {
Self {
gaussian_clouds: world.query(),
initialized: false,
view_bind_group: world.query(),
}
}
}
impl Node for RadixSortNode {
fn update(&mut self, world: &mut World) {
let pipeline = world.resource::<RadixSortPipeline>();
let pipeline_cache = world.resource::<PipelineCache>();
if !self.initialized {
let mut pipelines_loaded = true;
for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
if let CachedPipelineState::Ok(_) =
pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
{
continue;
}
pipelines_loaded = false;
}
self.initialized = pipelines_loaded;
if !self.initialized {
return;
}
}
self.gaussian_clouds.update_archetypes(world);
self.view_bind_group.update_archetypes(world);
}
fn run(
&self,
_graph: &mut RenderGraphContext,
render_context: &mut RenderContext,
world: &World,
) -> Result<(), NodeRunError> {
if !self.initialized {
return Ok(());
}
let pipeline_cache = world.resource::<PipelineCache>();
let pipeline = world.resource::<RadixSortPipeline>();
let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();
let command_encoder = render_context.command_encoder();
for (
view_bind_group,
view_uniform_offset,
) in self.view_bind_group.iter_manual(world) {
for (
cloud_handle,
cloud_bind_group,
radix_bind_group,
) in self.gaussian_clouds.iter_manual(world) {
let cloud = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap().get(cloud_handle).unwrap();
let radix_digit_places = ShaderDefines::default().radix_digit_places;
{
command_encoder.clear_buffer(
&cloud.radix_sort_buffers.sorting_global_buffer,
0,
None,
);
command_encoder.clear_buffer(
&cloud.radix_sort_buffers.sorting_status_counter_buffer,
0,
None,
);
command_encoder.clear_buffer(
&cloud.draw_indirect_buffer,
0,
None,
);
}
{
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
);
pass.set_bind_group(
2,
&cloud_bind_group.cloud_bind_group,
&[]
);
pass.set_bind_group(
3,
&radix_bind_group.radix_sort_bind_groups[1],
&[],
);
let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
pass.set_pipeline(radix_sort_a);
let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
pass.set_pipeline(radix_sort_b);
pass.dispatch_workgroups(1, radix_digit_places, 1);
}
for pass_idx in 0..radix_digit_places {
if pass_idx > 0 {
command_encoder.clear_buffer(
&cloud.radix_sort_buffers.sorting_status_counter_buffer,
0,
None,
);
}
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());
let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
pass.set_pipeline(radix_sort_c);
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
);
pass.set_bind_group(
2,
&cloud_bind_group.cloud_bind_group,
&[]
);
pass.set_bind_group(
3,
&radix_bind_group.radix_sort_bind_groups[pass_idx as usize],
&[],
);
let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1);
}
}
}
Ok(())
}
}
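For readers who want a reference for what the GPU node computes, here is a CPU sketch of a least-significant-digit radix sort together with the ceiling division used for the dispatch sizes. It is not the shader's algorithm: `Entry`, `workgroup_count`, and `radix_sort_cpu` are hypothetical names, and the choice of 4 digit places of 8 bits each is an assumption inferred from the four pass buffers, not a value confirmed by `ShaderDefines`.

```rust
/// Hypothetical stand-in for a key/index pair.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Entry {
    pub key: u32,
    pub index: u32,
}

/// Number of workgroups needed to cover `count` items (ceiling division),
/// the same arithmetic as the `dispatch_workgroups` arguments above.
pub fn workgroup_count(count: u32, entries_per_workgroup: u32) -> u32 {
    (count + entries_per_workgroup - 1) / entries_per_workgroup
}

/// Stable LSD radix sort by `key`, ascending.
/// ASSUMPTION: 4 passes over 8-bit digits.
pub fn radix_sort_cpu(entries: &mut Vec<Entry>) {
    const DIGIT_BITS: u32 = 8;
    const RADIX: usize = 1 << DIGIT_BITS;
    const DIGIT_PLACES: u32 = 32 / DIGIT_BITS;

    let mut src = std::mem::take(entries);
    let mut dst = vec![Entry::default(); src.len()];

    for pass in 0..DIGIT_PLACES {
        let shift = pass * DIGIT_BITS;

        // 1. histogram of the current digit
        let mut counts = [0usize; RADIX];
        for e in &src {
            counts[((e.key >> shift) as usize) & (RADIX - 1)] += 1;
        }

        // 2. exclusive prefix sum gives each digit's first output slot
        let mut offsets = [0usize; RADIX];
        let mut running = 0usize;
        for (digit, count) in counts.iter().enumerate() {
            offsets[digit] = running;
            running += *count;
        }

        // 3. stable scatter into the other buffer
        for e in &src {
            let digit = ((e.key >> shift) as usize) & (RADIX - 1);
            dst[offsets[digit]] = *e;
            offsets[digit] += 1;
        }

        // ping-pong: the output of this pass is the input of the next
        std::mem::swap(&mut src, &mut dst);
    }

    *entries = src;
}
```

Comparing a GPU result against such a CPU reference on a small cloud is one way to narrow down ordering bugs before looking at the shader.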
use bevy::{
prelude::*,
app::AppExit,
core::Name,
};
use bevy_inspector_egui::quick::WorldInspectorPlugin;
use bevy_panorbit_camera::{
PanOrbitCamera,
PanOrbitCameraPlugin,
};
use bevy_gaussian_splatting::{
Gaussian,
GaussianCloud,
GaussianCloudSettings,
GaussianSplattingBundle,
GaussianSplattingPlugin,
utils::setup_hooks, SphericalHarmonicCoefficients,
};
// TODO: move to editor crate
pub struct GaussianSplattingViewer {
pub editor: bool,
pub esc_close: bool,
pub show_fps: bool,
pub width: f32,
pub height: f32,
pub name: String,
}
impl Default for GaussianSplattingViewer {
fn default() -> GaussianSplattingViewer {
GaussianSplattingViewer {
editor: true,
esc_close: true,
show_fps: true,
width: 1920.0,
height: 1080.0,
name: "bevy_gaussian_splatting".to_string(),
}
}
}
pub fn setup_aabb_obb_compare(
mut commands: Commands,
mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
let mut blue_sh = SphericalHarmonicCoefficients::default();
blue_sh.coefficients[2] = 5.0;
let blue_aabb_gaussian = Gaussian {
position: [0.0, 0.0, 0.0, 1.0],
rotation: [0.89, 0.0, -0.432, 0.144],
scale_opacity: [10.0, 1.0, 1.0, 0.5],
spherical_harmonic: blue_sh,
};
commands.spawn((
GaussianSplattingBundle {
cloud: gaussian_assets.add(
GaussianCloud(vec![
blue_aabb_gaussian,
blue_aabb_gaussian,
])
),
settings: GaussianCloudSettings {
aabb: true,
visualize_bounding_box: true,
..default()
},
..default()
},
Name::new("gaussian_cloud_aabb"),
));
let mut red_sh = SphericalHarmonicCoefficients::default();
red_sh.coefficients[0] = 5.0;
let red_obb_gaussian = Gaussian {
position: [0.0, 0.0, 0.0, 1.0],
rotation: [0.89, 0.0, -0.432, 0.144],
scale_opacity: [10.0, 1.0, 1.0, 0.5],
spherical_harmonic: red_sh,
};
commands.spawn((
GaussianSplattingBundle {
cloud: gaussian_assets.add(
GaussianCloud(vec![
red_obb_gaussian,
red_obb_gaussian,
])
),
settings: GaussianCloudSettings {
aabb: false,
visualize_bounding_box: true,
..default()
},
..default()
},
Name::new("gaussian_cloud_obb"),
));
commands.spawn((
Camera3dBundle {
transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
..default()
},
PanOrbitCamera{
allow_upside_down: true,
..default()
},
));
}
fn compare_aabb_obb_app() {
let config = GaussianSplattingViewer::default();
let mut app = App::new();
// setup for gaussian viewer app
app.insert_resource(ClearColor(Color::rgb_u8(0, 0, 0)));
app.add_plugins(
DefaultPlugins
.set(ImagePlugin::default_nearest())
.set(WindowPlugin {
primary_window: Some(Window {
fit_canvas_to_parent: false,
mode: bevy::window::WindowMode::Windowed,
present_mode: bevy::window::PresentMode::AutoVsync,
prevent_default_event_handling: false,
resolution: (config.width, config.height).into(),
title: config.name.clone(),
..default()
}),
..default()
}),
);
app.add_plugins((
PanOrbitCameraPlugin,
));
if config.editor {
app.add_plugins(WorldInspectorPlugin::new());
}
if config.esc_close {
app.add_systems(Update, esc_close);
}
// setup for gaussian splatting
app.add_plugins(GaussianSplattingPlugin);
app.add_systems(Startup, setup_aabb_obb_compare);
app.run();
}
pub fn esc_close(
keys: Res<Input<KeyCode>>,
mut exit: EventWriter<AppExit>
) {
if keys.just_pressed(KeyCode::Escape) {
exit.send(AppExit);
}
}
pub fn main() {
setup_hooks();
compare_aabb_obb_app();
}
// TODO: minimal app
fn main() {
println!("Hello, world!");
}
None => return RenderCommandResult::Failure,
};
pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]);
pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);
pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);
RenderCommandResult::Success
}
}
struct RadixSortNode {
gaussian_clouds: QueryState<(
&'static Handle<GaussianCloud>,
&'static GaussianCloudBindGroup
)>,
initialized: bool,
pipeline_idx: Option<u32>,
view_bind_group: QueryState<(
&'static GaussianViewBindGroup,
&'static ViewUniformOffset,
)>,
}
impl FromWorld for RadixSortNode {
fn from_world(world: &mut World) -> Self {
Self {
gaussian_clouds: world.query(),
initialized: false,
pipeline_idx: None,
view_bind_group: world.query(),
}
}
}
impl render_graph::Node for RadixSortNode {
fn update(&mut self, world: &mut World) {
let pipeline = world.resource::<GaussianCloudPipeline>();
let pipeline_cache = world.resource::<PipelineCache>();
if !self.initialized {
let mut pipelines_loaded = true;
for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
if let CachedPipelineState::Ok(_) =
pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
{
continue;
}
pipelines_loaded = false;
}
self.initialized = pipelines_loaded;
if !self.initialized {
return;
}
}
if self.pipeline_idx.is_none() {
self.pipeline_idx = Some(0);
} else {
self.pipeline_idx = Some((self.pipeline_idx.unwrap() + 1) % pipeline.radix_sort_pipelines.len() as u32);
}
self.gaussian_clouds.update_archetypes(world);
self.view_bind_group.update_archetypes(world);
}
fn run(
&self,
_graph: &mut render_graph::RenderGraphContext,
render_context: &mut RenderContext,
world: &World,
) -> Result<(), render_graph::NodeRunError> {
if !self.initialized || self.pipeline_idx.is_none() {
return Ok(());
}
let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort
let pipeline_cache = world.resource::<PipelineCache>();
let pipeline = world.resource::<GaussianCloudPipeline>();
let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();
let command_encoder = render_context.command_encoder();
for (
view_bind_group,
view_uniform_offset,
) in self.view_bind_group.iter_manual(world) {
for (
cloud_handle,
cloud_bind_group
) in self.gaussian_clouds.iter_manual(world) {
let cloud = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap().get(cloud_handle).unwrap();
let radix_digit_places = ShaderDefines::default().radix_digit_places;
command_encoder.clear_buffer(
&cloud.sorting_global_buffer,
0,
None,
);
{
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());
// TODO: view/global
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
);
pass.set_bind_group(
2,
&cloud_bind_group.cloud_bind_group,
&[]
);
pass.set_bind_group(
3,
&cloud_bind_group.radix_sort_bind_groups[1],
&[],
);
let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
pass.set_pipeline(radix_sort_a);
let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
pass.set_pipeline(radix_sort_b);
pass.dispatch_workgroups(1, radix_digit_places, 1);
}
for pass_idx in 0..radix_digit_places {
if pass_idx > 0 {
let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::<u32>() as u32;
command_encoder.clear_buffer(
&cloud.sorting_global_buffer,
0,
std::num::NonZeroU64::new(size as u64).unwrap().into()
);
}
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());
let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
pass.set_pipeline(radix_sort_c);
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
);
pass.set_bind_group(
2,
&cloud_bind_group.cloud_bind_group,
&[]
);
pass.set_bind_group(
3,
&cloud_bind_group.radix_sort_bind_groups[pass_idx as usize],
&[],
);
let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
}
}
}
Ok(())
}
}
// TODO: support streamed codecs
pub trait GaussianCloudCodec {
fn encode(&self) -> Vec<u8>;
fn decode(data: &[u8]) -> Self;
}
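To make the shape of the trait concrete, here is a minimal sketch of an implementation for a toy type. `PositionCloud` is hypothetical and exists only for illustration; the little-endian `f32` layout is an assumed encoding, not the crate's `.gcloud` format. The trait is repeated so the example is self-contained.

```rust
pub trait GaussianCloudCodec {
    fn encode(&self) -> Vec<u8>;
    fn decode(data: &[u8]) -> Self;
}

/// Hypothetical toy type: positions only, three f32 values per point.
#[derive(Debug, PartialEq)]
pub struct PositionCloud {
    pub positions: Vec<[f32; 3]>,
}

impl GaussianCloudCodec for PositionCloud {
    /// Writes each component as 4 little-endian bytes (12 bytes per point).
    fn encode(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.positions.len() * 12);
        for position in &self.positions {
            for component in position {
                out.extend_from_slice(&component.to_le_bytes());
            }
        }
        out
    }

    /// Reads back 12-byte records; trailing bytes that do not form a
    /// complete record are ignored.
    fn decode(data: &[u8]) -> Self {
        let positions = data
            .chunks_exact(12)
            .map(|record| {
                let mut position = [0.0f32; 3];
                for (i, bytes) in record.chunks_exact(4).enumerate() {
                    position[i] = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
                }
                position
            })
            .collect();
        PositionCloud { positions }
    }
}
```

Since `decode` returns `Self` rather than a `Result`, malformed input can only be ignored or cause a panic; a streamed codec as in the TODO would likely need a fallible, incremental interface instead.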
bevy_gaussian_splatting/src/sort/rayon.rs
Line 51 in 5c9a20a
use bevy::{
prelude::*,
asset::LoadState,
utils::Instant,
};
use rayon::prelude::*;
use crate::{
GaussianCloud,
GaussianCloudSettings,
sort::{
SortedEntries,
SortMode,
},
};
#[derive(Default)]
pub struct RayonSortPlugin;
impl Plugin for RayonSortPlugin {
fn build(&self, app: &mut App) {
app.add_systems(Update, rayon_sort);
}
}
pub fn rayon_sort(
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
gaussian_clouds: Query<(
&Handle<GaussianCloud>,
&Handle<SortedEntries>,
&GaussianCloudSettings,
)>,
cameras: Query<(
&GlobalTransform,
&Camera3d,
)>,
mut last_camera_position: Local<Vec3>,
mut last_sort_time: Local<Option<Instant>>,
) {
let period = std::time::Duration::from_millis(100);
if let Some(last_sort_time) = last_sort_time.as_ref() {
if last_sort_time.elapsed() < period {
return;
}
}
// TODO: move sort to render world, use extracted views and update the existing buffer instead of creating a new one
for (
camera_transform,
_camera,
) in cameras.iter() {
let camera_position = camera_transform.compute_transform().translation;
if *last_camera_position == camera_position {
return;
}
for (
gaussian_cloud_handle,
sorted_entries_handle,
settings,
) in gaussian_clouds.iter() {
if settings.sort_mode != SortMode::Rayon {
continue;
}
if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
continue;
}
if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) {
continue;
}
if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) {
if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) {
assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len());
*last_camera_position = camera_position;
*last_sort_time = Some(Instant::now());
gaussian_cloud.gaussians.par_iter()
.zip(sorted_entries.sorted.par_iter_mut())
.enumerate()
.for_each(|(idx, (gaussian, sort_entry))| {
let position = Vec3::from_slice(gaussian.position.as_ref());
let delta = camera_position - position;
sort_entry.key = bytemuck::cast(delta.length_squared());
sort_entry.index = idx as u32;
});
sorted_entries.sorted.par_sort_unstable_by(|a, b| {
// total_cmp defines a total order, so a NaN key cannot panic the sort
bytemuck::cast::<u32, f32>(b.key).total_cmp(&bytemuck::cast::<u32, f32>(a.key))
});
// TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)
}
}
}
}
}
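The core of the system above is a back-to-front sort by squared camera distance, with the `f32` distance stored bit-for-bit in the `u32` key. A std-only, single-threaded sketch of that step follows. `Entry` and `sort_back_to_front` are hypothetical names; the sketch uses `f32::to_bits`/`from_bits` in place of `bytemuck::cast` and `f32::total_cmp` for the comparison, which is one way to avoid a panic if a key is NaN.

```rust
/// Hypothetical stand-in for a key/index pair.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Entry {
    pub key: u32,
    pub index: u32,
}

/// Fills `entries` with (squared distance bits, index) and sorts them
/// so the farthest gaussian comes first.
pub fn sort_back_to_front(positions: &[[f32; 3]], camera: [f32; 3], entries: &mut [Entry]) {
    assert_eq!(positions.len(), entries.len());

    for (idx, (position, entry)) in positions.iter().zip(entries.iter_mut()).enumerate() {
        let delta = [
            camera[0] - position[0],
            camera[1] - position[1],
            camera[2] - position[2],
        ];
        let distance_squared = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];
        // store the float's raw bits in the u32 key
        entry.key = distance_squared.to_bits();
        entry.index = idx as u32;
    }

    // descending by distance: compare b against a
    entries.sort_unstable_by(|a, b| f32::from_bits(b.key).total_cmp(&f32::from_bits(a.key)));
}
```

The squared distance is enough for ordering because squaring preserves the order of non-negative lengths, so the square root can be skipped.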
bevy_gaussian_splatting/viewer/viewer.rs
Line 61 in cd9fd07
..default()
};
// TODO: add proper GaussianSplattingViewer argument parsing
let file_arg = std::env::args().nth(1);
if let Some(n) = file_arg.clone().and_then(|s| s.parse::<usize>().ok()) {
println!("generating {} gaussians", n);
cloud = gaussian_assets.add(random_gaussians(n));
use bevy::{
prelude::*,
asset::LoadState,
ecs::query::QueryItem,
render::{
extract_component::{
ExtractComponent,
ExtractComponentPlugin,
},
Render,
RenderApp,
RenderSet,
render_asset::RenderAssets,
render_resource::{
BindGroup,
BindGroupLayout,
BindGroupLayoutDescriptor,
BindGroupLayoutEntry,
BindGroupEntry,
BindingType,
BindingResource,
Extent3d,
TextureDimension,
TextureFormat,
TextureSampleType,
TextureUsages,
TextureViewDimension,
ShaderStages,
},
renderer::RenderDevice,
},
};
use static_assertions::assert_cfg;
#[allow(unused_imports)]
use crate::{
gaussian::{
cloud::GaussianCloud,
f32::{
PositionVisibility,
Rotation,
ScaleOpacity,
},
},
material::spherical_harmonics::{
SH_COEFF_COUNT,
SH_VEC4_PLANES,
SphericalHarmonicCoefficients,
},
render::{
GaussianCloudPipeline,
GpuGaussianCloud,
},
};
// TODO: support loading from directory of images
assert_cfg!(
feature = "planar",
"texture rendering is only supported with the `planar` feature enabled",
);
assert_cfg!(
not(feature = "f32"),
"f32 texture support is not implemented yet",
);
#[derive(Component, Clone, Debug, Reflect)]
pub struct TextureBuffers {
position_visibility: Handle<Image>,
spherical_harmonics: Handle<Image>,
#[cfg(feature = "f16")]
rotation_scale_opacity: Handle<Image>,
#[cfg(feature = "f32")]
rotation: Handle<Image>,
#[cfg(feature = "f32")]
scale_opacity: Handle<Image>,
}
impl ExtractComponent for TextureBuffers {
type Query = &'static Self;
type Filter = ();
type Out = Self;
fn extract_component(texture_buffers: QueryItem<'_, Self::Query>) -> Option<Self::Out> {
texture_buffers.clone().into()
}
}
#[derive(Default)]
pub struct BufferTexturePlugin;
impl Plugin for BufferTexturePlugin {
fn build(&self, app: &mut App) {
app.register_type::<TextureBuffers>();
app.add_plugins(ExtractComponentPlugin::<TextureBuffers>::default());
app.add_systems(Update, queue_textures);
let render_app = app.sub_app_mut(RenderApp);
render_app.add_systems(
Render,
queue_gpu_texture_buffers.in_set(RenderSet::PrepareAssets),
);
}
}
#[derive(Component, Clone, Debug)]
pub struct GpuTextureBuffers {
pub bind_group: BindGroup,
}
pub fn queue_gpu_texture_buffers(
mut commands: Commands,
// gaussian_cloud_pipeline: Res<GaussianCloudPipeline>,
// both resources are only read here, so shared access is sufficient
pipeline: Res<GaussianCloudPipeline>,
render_device: Res<RenderDevice>,
gpu_images: Res<RenderAssets<Image>>,
clouds: Query<(
Entity,
&TextureBuffers,
)>,
) {
for (entity, texture_buffers,) in clouds.iter() {
#[cfg(feature = "f16")]
let bind_group = render_device.create_bind_group(
Some("texture_gaussian_cloud_bind_group"),
&pipeline.gaussian_cloud_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.rotation_scale_opacity).unwrap().texture_view
),
},
],
);
#[cfg(feature = "f32")]
let bind_group = render_device.create_bind_group(
Some("texture_gaussian_cloud_bind_group"),
&pipeline.gaussian_cloud_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.rotation).unwrap().texture_view
),
},
BindGroupEntry {
binding: 3,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.scale_opacity).unwrap().texture_view
),
},
],
);
commands.entity(entity).insert(GpuTextureBuffers { bind_group });
}
}
// TODO: support asset change detection and reupload
fn queue_textures(
mut commands: Commands,
asset_server: Res<AssetServer>,
gaussian_cloud_res: Res<Assets<GaussianCloud>>,
mut images: ResMut<Assets<Image>>,
clouds: Query<(
Entity,
&Handle<GaussianCloud>,
Without<TextureBuffers>,
)>,
) {
for (entity, cloud_handle, _) in clouds.iter() {
if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle){
continue;
}
// skip clouds whose asset data is not available yet
let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
continue;
};
let square = cloud.len_sqrt_ceil() as u32;
let extent_1d = Extent3d {
width: square,
height: square, // TODO: shrink height to save memory (consider fixed width)
depth_or_array_layers: 1,
};
let mut position_visibility = Image::new(
extent_1d,
TextureDimension::D2,
bytemuck::cast_slice(cloud.position_visibility.as_slice()).to_vec(),
TextureFormat::Rgba32Float,
);
position_visibility.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let position_visibility = images.add(position_visibility);
let texture_buffers: TextureBuffers;
#[cfg(feature = "f16")]
{
let planar_spherical_harmonics: Vec<u32> = (0..SH_VEC4_PLANES)
.flat_map(|plane_index| {
cloud.spherical_harmonic.iter()
.flat_map(move |sh| {
let start_index = plane_index * 4;
let end_index = std::cmp::min(start_index + 4, sh.coefficients.len());
let mut depthwise = sh.coefficients[start_index..end_index].to_vec();
depthwise.resize(4, 0);
depthwise
})
})
.collect();
let mut spherical_harmonics = Image::new(
Extent3d {
width: square,
height: square,
depth_or_array_layers: SH_VEC4_PLANES as u32,
},
TextureDimension::D2,
bytemuck::cast_slice(planar_spherical_harmonics.as_slice()).to_vec(),
TextureFormat::Rgba32Uint,
);
spherical_harmonics.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let spherical_harmonics = images.add(spherical_harmonics);
let mut rotation_scale_opacity = Image::new(
extent_1d,
TextureDimension::D2,
bytemuck::cast_slice(cloud.rotation_scale_opacity_packed128.as_slice()).to_vec(),
TextureFormat::Rgba32Uint,
);
rotation_scale_opacity.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let rotation_scale_opacity = images.add(rotation_scale_opacity);
texture_buffers = TextureBuffers {
position_visibility,
spherical_harmonics,
rotation_scale_opacity,
};
}
#[cfg(feature = "f32")]
{
texture_buffers = TextureBuffers {
position_visibility,
spherical_harmonics: todo!(),
rotation: todo!(),
scale_opacity: todo!(),
};
}
commands.entity(entity).insert(texture_buffers);
}
}
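The square texture sizing in `queue_textures` can be illustrated with a small sketch. It assumes `len_sqrt_ceil` returns the ceiling of the square root of the gaussian count and that texels are addressed in row-major order; `sqrt_ceil`, `padded_len` and `texel_coord` are hypothetical helpers written for this illustration and are not part of the crate.

```rust
/// Side length of the smallest square texture that can hold `len` texels
/// (assumed to mirror what `len_sqrt_ceil` computes).
pub fn sqrt_ceil(len: usize) -> u32 {
    (len as f64).sqrt().ceil() as u32
}

/// Number of texels in that square texture. The image data would need to be
/// padded up to this length when `len` is not a perfect square.
pub fn padded_len(len: usize) -> usize {
    let side = sqrt_ceil(len) as usize;
    side * side
}

/// Row-major mapping from a linear gaussian index to (x, y) texel coordinates.
pub fn texel_coord(index: u32, width: u32) -> (u32, u32) {
    (index % width, index / width)
}
```

Under these assumptions a cloud of 10001 gaussians needs a 101 x 101 texture, which leaves 200 unused texels. That overhead is what the `TODO: shrink height to save memory` comment refers to.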
pub fn get_sorted_bind_group_layout(
render_device: &RenderDevice,
) -> BindGroupLayout {
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_sorted_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
],
})
}
#[cfg(feature = "f16")]
pub fn get_bind_group_layout(
render_device: &RenderDevice,
_read_only: bool
) -> BindGroupLayout {
let sh_view_dimension = if SH_VEC4_PLANES == 1 {
TextureViewDimension::D2
} else {
TextureViewDimension::D2Array
};
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_f16_gaussian_cloud_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Float {
filterable: false,
},
multisampled: false,
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: sh_view_dimension,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
],
})
}
#[cfg(feature = "f32")]
pub fn get_bind_group_layout(
render_device: &RenderDevice,
read_only: bool
) -> BindGroupLayout {
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_f32_gaussian_cloud_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<PositionVisibility>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<SphericalHarmonicCoefficients>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<Rotation>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<ScaleOpacity>() as u64),
},
count: None,
},
],
})
}
],
});
let sorting_buffer_entry = BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(ShaderDefines::default().sorting_buffer_size as u64),
},
count: None,
};
let draw_indirect_buffer_entry = BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<wgpu::util::DrawIndirect>() as u64),
},
count: None,
};
let radix_sort_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("radix_sort_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<u32>() as u64),
},
count: None,
},
sorting_buffer_entry,
draw_indirect_buffer_entry,
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
},
count: None,
},
],
});
let sorted_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("sorted_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 5,
visibility: ShaderStages::VERTEX_FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
},
count: None,
},
],
});
let compute_layout = vec![
view_layout.clone(),
gaussian_uniform_layout.clone(),
gaussian_cloud_layout.clone(),
radix_sort_layout.clone(),
];
let shader = GAUSSIAN_SHADER_HANDLE.typed();
let shader_defs = shader_defs(false, false);
let pipeline_cache = render_world.resource::<PipelineCache>();
let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_a".into()),
layout: compute_layout.clone(),
push_constant_ranges: vec![],
shader: shader.clone(),
shader_defs: shader_defs.clone(),
entry_point: "radix_sort_a".into(),
});
let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_b".into()),
layout: compute_layout.clone(),
push_constant_ranges: vec![],
shader: shader.clone(),
shader_defs: shader_defs.clone(),
entry_point: "radix_sort_b".into(),
});
let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_c".into()),
layout: compute_layout.clone(),
push_constant_ranges: vec![],
shader: shader.clone(),
shader_defs: shader_defs.clone(),
entry_point: "radix_sort_c".into(),
});
let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("temporal_sort_flip".into()),
layout: compute_layout.clone(),
push_constant_ranges: vec![],
shader: shader.clone(),
shader_defs: shader_defs.clone(),
entry_point: "temporal_sort_flip".into(),
});
let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("temporal_sort_flop".into()),
layout: compute_layout.clone(),
push_constant_ranges: vec![],
shader: shader.clone(),
shader_defs: shader_defs.clone(),
entry_point: "temporal_sort_flop".into(),
});
GaussianCloudPipeline {
gaussian_cloud_layout,
gaussian_uniform_layout,
view_layout,
shader: shader.clone(),
radix_sort_layout,
radix_sort_pipelines: [
radix_sort_a,
radix_sort_b,
radix_sort_c,
],
temporal_sort_pipelines: [
temporal_sort_flip,
temporal_sort_flop,
],
sorted_layout,
}
}
}
// TODO: allow setting shader defines via API
struct ShaderDefines {
radix_bits_per_digit: u32,
radix_digit_places: u32,
radix_base: u32,
entries_per_invocation_a: u32,
entries_per_invocation_c: u32,
workgroup_invocations_a: u32,
workgroup_invocations_c: u32,
workgroup_entries_a: u32,
workgroup_entries_c: u32,
max_tile_count_c: u32,
sorting_buffer_size: usize,
temporal_sort_window_size: u32,
}
impl Default for ShaderDefines {
fn default() -> Self {
let radix_bits_per_digit = 8;
let radix_digit_places = 32 / radix_bits_per_digit;
let radix_base = 1 << radix_bits_per_digit;
let entries_per_invocation_a = 4;
let entries_per_invocation_c = 4;
let workgroup_invocations_a = radix_base * radix_digit_places;
let workgroup_invocations_c = radix_base;
let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a;
let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c;
// ceiling division: tiles needed to cover a hard-coded maximum of 10,000,000 entries
let max_tile_count_c = (10_000_000 + workgroup_entries_c - 1) / workgroup_entries_c;
let sorting_buffer_size = (
radix_base as usize *
(radix_digit_places as usize + max_tile_count_c as usize) *
std::mem::size_of::<u32>()
) + std::mem::size_of::<u32>() * 5;
Self {
radix_bits_per_digit,
radix_digit_places,
radix_base,
entries_per_invocation_a,
entries_per_invocation_c,
workgroup_invocations_a,
workgroup_invocations_c,
workgroup_entries_a,
workgroup_entries_c,
max_tile_count_c,
sorting_buffer_size,
temporal_sort_window_size: 16,
}
}
}
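To make the arithmetic in `ShaderDefines::default()` concrete, here is a standalone sketch that reproduces the same formulas. `RadixSizes` and `radix_sizes` are names invented for this illustration, and `max_entries` stands in for the hard-coded 10000000 in the original.

```rust
/// Subset of the values derived in `ShaderDefines::default()`.
#[derive(Debug, PartialEq)]
pub struct RadixSizes {
    pub radix_base: u32,
    pub radix_digit_places: u32,
    pub workgroup_entries_c: u32,
    pub max_tile_count_c: u32,
    pub sorting_buffer_size: usize,
}

/// Recomputes the sizes with the same formulas as the original code.
pub fn radix_sizes(
    radix_bits_per_digit: u32,
    entries_per_invocation_c: u32,
    max_entries: u32,
) -> RadixSizes {
    // a 32-bit key is split into digits of `radix_bits_per_digit` bits each
    let radix_digit_places = 32 / radix_bits_per_digit;
    let radix_base = 1u32 << radix_bits_per_digit;
    // pass C uses one invocation per radix bucket
    let workgroup_entries_c = radix_base * entries_per_invocation_c;
    // ceiling division: number of tiles needed to cover `max_entries`
    let max_tile_count_c = (max_entries + workgroup_entries_c - 1) / workgroup_entries_c;
    let word = std::mem::size_of::<u32>();
    let sorting_buffer_size = radix_base as usize
        * (radix_digit_places as usize + max_tile_count_c as usize)
        * word
        + word * 5;
    RadixSizes {
        radix_base,
        radix_digit_places,
        workgroup_entries_c,
        max_tile_count_c,
        sorting_buffer_size,
    }
}
```

With the defaults (8 bits per digit, 4 entries per invocation, 10,000,000 entries) this gives a radix base of 256, 4 digit places, 1024 entries per workgroup, 9766 tiles, and a sorting buffer of 10,004,500 bytes regardless of how many gaussians the cloud holds.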
fn shader_defs(
aabb: bool,
visualize_bounding_box: bool,
) -> Vec<ShaderDefVal> {
let defines = ShaderDefines::default();
let mut shader_defs = vec![
ShaderDefVal::UInt("MAX_SH_COEFF_COUNT".into(), MAX_SH_COEFF_COUNT as u32),
ShaderDefVal::UInt("RADIX_BASE".into(), defines.radix_base),
ShaderDefVal::UInt("RADIX_BITS_PER_DIGIT".into(), defines.radix_bits_per_digit),
ShaderDefVal::UInt("RADIX_DIGIT_PLACES".into(), defines.radix_digit_places),
ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_A".into(), defines.entries_per_invocation_a),
ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_C".into(), defines.entries_per_invocation_c),
ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a),
ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c),
ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c),
ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c),
ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size),
];
if aabb {
shader_defs.push("USE_AABB".into());
} else {
shader_defs.push("USE_OBB".into());
}
if visualize_bounding_box {
shader_defs.push("VISUALIZE_BOUNDING_BOX".into());
}
shader_defs
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct GaussianCloudPipelineKey {
pub aabb: bool,
bevy_gaussian_splatting/tests/gpu/gaussian.rs
Line 146 in 5c9a20a
use std::sync::{
Arc,
Mutex,
};
use bevy::{
prelude::*,
app::AppExit,
core::FrameCount,
render::view::screenshot::ScreenshotManager,
window::PrimaryWindow,
};
use bevy_gaussian_splatting::{
GaussianCloud,
GaussianSplattingBundle,
random_gaussians,
};
use _harness::{
TestHarness,
test_harness_app,
TestStateArc,
};
mod _harness;
fn main() {
let mut app = test_harness_app(TestHarness {
resolution: (512.0, 512.0),
});
app.add_systems(Startup, setup);
app.add_systems(Update, capture_ready);
app.run();
}
fn setup(
mut commands: Commands,
mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
let cloud = gaussian_assets.add(random_gaussians(10000));
commands.spawn((
GaussianSplattingBundle {
cloud,
..default()
},
Name::new("gaussian_cloud"),
));
commands.spawn((
Camera3dBundle {
transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
..default()
},
));
}
fn check_image_equality(image: &Image, other: &Image) -> bool {
if image.width() != other.width() || image.height() != other.height() {
return false;
}
// compare the raw bytes directly; unlike zip, this also catches a length mismatch
image.data == other.data
}
fn test_stability(captures: Arc<Mutex<Vec<Image>>>) {
// every frame must match its predecessor, which implies all frames are identical
let captures = captures.lock().unwrap();
let all_frames_similar = captures
.windows(2)
.all(|pair| check_image_equality(&pair[0], &pair[1]));
assert!(all_frames_similar, "all frames are not the same");
}
fn save_captures(captures: Arc<Mutex<Vec<Image>>>) {
captures.lock().unwrap().iter()
.enumerate()
.for_each(|(i, image)| {
let path = format!("target/tmp/test_gaussian_frame_{}.png", i);
let dyn_img = image.clone().try_into_dynamic().unwrap();
let img = dyn_img.to_rgba8();
img.save(path).unwrap();
});
}
fn capture_ready(
// gaussian_cloud_assets: Res<Assets<GaussianCloud>>,
// asset_server: Res<AssetServer>,
// gaussian_clouds: Query<
// Entity,
// &Handle<GaussianCloud>,
// >,
main_window: Query<Entity, With<PrimaryWindow>>,
mut screenshot_manager: ResMut<ScreenshotManager>,
mut exit: EventWriter<AppExit>,
frame_count: Res<FrameCount>,
state: Local<TestStateArc>,
buffer: Local<Arc<Mutex<Vec<Image>>>>,
) {
let buffer = buffer.to_owned();
let buffer_frames = 10;
let wait_frames = 10; // wait for gaussian cloud to load
if frame_count.0 < wait_frames {
return;
}
let state_clone = Arc::clone(&state);
let buffer_clone = Arc::clone(&buffer);
let mut state = state.lock().unwrap();
state.test_loaded = true;
if state.test_completed {
{
let captures = buffer.lock().unwrap();
let frame_count = captures.len();
assert_eq!(frame_count, buffer_frames, "captured {} frames, expected {}", frame_count, buffer_frames);
}
save_captures(buffer.clone());
test_stability(buffer);
// TODO: add correctness test (use CPU gaussian pipeline to compare results)
exit.send(AppExit);
return;
}
if let Ok(window_entity) = main_window.get_single() {
screenshot_manager.take_screenshot(window_entity, move |image: Image| {
let has_non_zero_data = image.data.iter().any(|&x| x != 0);
assert!(has_non_zero_data, "screenshot is all zeros");
let mut buffer = buffer_clone.lock().unwrap();
buffer.push(image);
if buffer.len() >= buffer_frames {
let mut state = state_clone.lock().unwrap();
state.test_completed = true;
}
}).unwrap();
}
}
bevy_gaussian_splatting/src/render/mod.rs
Line 988 in 507a28d
None => return RenderCommandResult::Failure,
};
pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]); // TODO: abstract source of cloud_bind_group (e.g. packed vs. planar)
pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);
#[cfg(feature = "webgl2")]
pass.draw(0..4, 0..gpu_gaussian_cloud.count as u32);
#[cfg(not(feature = "webgl2"))]
pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);
RenderCommandResult::Success
bevy_gaussian_splatting/src/sort/rayon.rs
Line 101 in 5c9a20a
use bevy::{
prelude::*,
asset::LoadState,
utils::Instant,
};
use rayon::prelude::*;
use crate::{
GaussianCloud,
GaussianCloudSettings,
sort::{
SortedEntries,
SortMode,
},
};
#[derive(Default)]
pub struct RayonSortPlugin;
impl Plugin for RayonSortPlugin {
fn build(&self, app: &mut App) {
app.add_systems(Update, rayon_sort);
}
}
pub fn rayon_sort(
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
gaussian_clouds: Query<(
&Handle<GaussianCloud>,
&Handle<SortedEntries>,
&GaussianCloudSettings,
)>,
cameras: Query<(
&GlobalTransform,
&Camera3d,
)>,
mut last_camera_position: Local<Vec3>,
mut last_sort_time: Local<Option<Instant>>,
) {
let period = std::time::Duration::from_millis(100);
if let Some(last_sort_time) = last_sort_time.as_ref() {
if last_sort_time.elapsed() < period {
return;
}
}
// TODO: move sort to render world, use extracted views and update the existing buffer instead of creating new
for (
camera_transform,
_camera,
) in cameras.iter() {
let camera_position = camera_transform.compute_transform().translation;
if *last_camera_position == camera_position {
return;
}
for (
gaussian_cloud_handle,
sorted_entries_handle,
settings,
) in gaussian_clouds.iter() {
if settings.sort_mode != SortMode::Rayon {
continue;
}
if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
continue;
}
if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) {
continue;
}
if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) {
if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) {
assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len());
*last_camera_position = camera_position;
*last_sort_time = Some(Instant::now());
gaussian_cloud.gaussians.par_iter()
.zip(sorted_entries.sorted.par_iter_mut())
.enumerate()
.for_each(|(idx, (gaussian, sort_entry))| {
let position = Vec3::from_slice(gaussian.position.as_ref());
let delta = camera_position - position;
sort_entry.key = bytemuck::cast(delta.length_squared());
sort_entry.index = idx as u32;
});
sorted_entries.sorted.par_sort_unstable_by(|a, b| {
// total_cmp defines an order for every f32, so a NaN key cannot panic the sort
bytemuck::cast::<u32, f32>(b.key).total_cmp(&bytemuck::cast::<u32, f32>(a.key))
});
// TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)
}
}
}
}
}
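The key scheme used by `rayon_sort` can be sketched without Bevy or rayon. The version below is a single-threaded illustration using only the standard library. `sort_back_to_front` is a hypothetical helper, and `SortEntry` is redeclared locally so the sketch stands alone. It relies on the fact that for non-negative, non-NaN `f32` values the IEEE 754 bit pattern orders the same way as the float itself, so the `u32` keys can be compared directly.

```rust
/// Local stand-in for the crate's `SortEntry` (same two fields).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SortEntry {
    pub key: u32,
    pub index: u32,
}

/// Builds one entry per position, keyed by squared distance to the camera,
/// and sorts the entries farthest-first (back to front).
pub fn sort_back_to_front(positions: &[[f32; 3]], camera: [f32; 3]) -> Vec<SortEntry> {
    let mut entries: Vec<SortEntry> = positions
        .iter()
        .enumerate()
        .map(|(idx, p)| {
            let d = [camera[0] - p[0], camera[1] - p[1], camera[2] - p[2]];
            let dist_sq = d[0] * d[0] + d[1] * d[1] + d[2] * d[2];
            // store the raw f32 bits, as the original does with bytemuck::cast
            SortEntry { key: dist_sq.to_bits(), index: idx as u32 }
        })
        .collect();
    // assumption: positions are finite, so every key is a non-negative float
    // and comparing the bit patterns as u32 matches comparing the floats
    entries.sort_unstable_by(|a, b| b.key.cmp(&a.key));
    entries
}
```

Comparing the integer keys avoids casting back to `f32` inside the comparator. Whether that is measurably faster than the float comparison in the original has not been checked here.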
bevy_gaussian_splatting/src/render/mod.rs
Line 223 in 507a28d
usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC,
});
// TODO: (extract GaussianCloud, TextureBuffers) when feature buffer_texture is enabled
Ok(GpuGaussianCloud {
count,
draw_indirect_buffer,
#[cfg(feature = "debug_gpu")]
debug_gpu: gaussian_cloud,
#[cfg(feature = "packed")]
packed: packed::prepare_cloud(render_device, &gaussian_cloud),
#[cfg(feature = "buffer_storage")]
planar: planar::prepare_cloud(render_device, &gaussian_cloud),
})
}
}
bevy_gaussian_splatting/src/render/mod.rs
Line 759 in 507a28d
asset_server: Res<AssetServer>,
gaussian_cloud_res: Res<RenderAssets<GaussianCloud>>,
sorted_entries_res: Res<RenderAssets<SortedEntries>>,
#[cfg(feature = "buffer_storage")]
gaussian_clouds: Query<(
Entity,
&Handle<GaussianCloud>,
&Handle<SortedEntries>,
)>,
#[cfg(feature = "buffer_texture")]
gaussian_clouds: Query<(
Entity,
&Handle<GaussianCloud>,
&Handle<SortedEntries>,
&texture::GpuTextureBuffers,
)>,
#[cfg(feature = "buffer_texture")]
gpu_images: Res<RenderAssets<Image>>,
) {
let Some(model) = gaussian_uniforms.buffer() else {
return;
};
// TODO: overloaded system, move to resource setup system
groups.base_bind_group = Some(render_device.create_bind_group(
"gaussian_uniform_bind_group",
&gaussian_cloud_pipeline.gaussian_uniform_layout,
println!("initial cloud size: {}", cloud.len());
cloud = (0..cloud.len())
.filter(|&idx| {
is_point_in_transformed_sphere(
cloud.position(idx),
)
})
.map(|idx| cloud.gaussian(idx))
.collect();
println!("filtered position cloud size: {}", cloud.len());
let file = std::fs::File::open(&filename).expect("failed to open file");
let mut reader = std::io::BufReader::new(file);
let mut cloud = GaussianCloud::from_gaussians(
parse_ply(&mut reader).expect("failed to parse ply file"),
);
// TODO: prioritize mesh selection over export filter
// println!("initial cloud size: {}", cloud.len());
// cloud = (0..cloud.len())
// .filter(|&idx| {
// is_point_in_transformed_sphere(
// cloud.position(idx),
// )
// })
// .map(|idx| cloud.gaussian(idx))
// .collect();
// println!("filtered position cloud size: {}", cloud.len());
#[cfg(feature = "query_sparse")]
{
let sparse_selection = SparseSelect::default().select(&cloud).invert(cloud.len());
cloud = sparse_selection.indicies.iter()
.map(|idx| cloud.gaussian(*idx))
.collect();
println!("sparsity filtered cloud size: {}", cloud.len());
}
let base_filename = filename.split('.').next().expect("no extension").to_string();
let gcloud_filename = base_filename + ".gcloud";
write_gaussian_cloud_to_file(&cloud, &gcloud_filename);
let post_encode_bytes = Byte::from_u64(std::fs::metadata(&gcloud_filename).expect("failed to get metadata").len());
println!("output file size: {}", post_encode_bytes.get_appropriate_unit(UnitType::Decimal));
}
bevy_gaussian_splatting/src/render/texture.rs
Line 196 in 507a28d
use bevy::{
prelude::*,
asset::LoadState,
ecs::query::QueryItem,
render::{
extract_component::{
ExtractComponent,
ExtractComponentPlugin,
},
Render,
RenderApp,
RenderSet,
render_asset::RenderAssets,
render_resource::{
BindGroup,
BindGroupLayout,
BindGroupLayoutDescriptor,
BindGroupLayoutEntry,
BindGroupEntry,
BindingType,
BindingResource,
Extent3d,
TextureDimension,
TextureFormat,
TextureSampleType,
TextureUsages,
TextureViewDimension,
ShaderStages,
},
renderer::RenderDevice,
},
};
use static_assertions::assert_cfg;
#[allow(unused_imports)]
use crate::{
gaussian::{
cloud::GaussianCloud,
f32::{
PositionVisibility,
Rotation,
ScaleOpacity,
},
},
material::spherical_harmonics::{
SH_COEFF_COUNT,
SH_VEC4_PLANES,
SphericalHarmonicCoefficients,
},
render::{
GaussianCloudPipeline,
GpuGaussianCloud,
},
};
// TODO: support loading from directory of images
assert_cfg!(
feature = "planar",
"texture rendering is only supported with the `planar` feature enabled",
);
assert_cfg!(
not(feature = "f32"),
"f32 texture support is not implemented yet",
);
#[derive(Component, Clone, Debug, Reflect)]
pub struct TextureBuffers {
position_visibility: Handle<Image>,
spherical_harmonics: Handle<Image>,
#[cfg(feature = "f16")]
rotation_scale_opacity: Handle<Image>,
#[cfg(feature = "f32")]
rotation: Handle<Image>,
#[cfg(feature = "f32")]
scale_opacity: Handle<Image>,
}
impl ExtractComponent for TextureBuffers {
type Query = &'static Self;
type Filter = ();
type Out = Self;
fn extract_component(texture_buffers: QueryItem<'_, Self::Query>) -> Option<Self::Out> {
Some(texture_buffers.clone())
}
}
#[derive(Default)]
pub struct BufferTexturePlugin;
impl Plugin for BufferTexturePlugin {
fn build(&self, app: &mut App) {
app.register_type::<TextureBuffers>();
app.add_plugins(ExtractComponentPlugin::<TextureBuffers>::default());
app.add_systems(Update, queue_textures);
let render_app = app.sub_app_mut(RenderApp);
render_app.add_systems(
Render,
queue_gpu_texture_buffers.in_set(RenderSet::PrepareAssets),
);
}
}
#[derive(Component, Clone, Debug)]
pub struct GpuTextureBuffers {
pub bind_group: BindGroup,
}
pub fn queue_gpu_texture_buffers(
mut commands: Commands,
// gaussian_cloud_pipeline: Res<GaussianCloudPipeline>,
// neither resource is mutated here, so shared access is enough
pipeline: Res<GaussianCloudPipeline>,
render_device: Res<RenderDevice>,
gpu_images: Res<RenderAssets<Image>>,
clouds: Query<(
Entity,
&TextureBuffers,
)>,
) {
for (entity, texture_buffers,) in clouds.iter() {
#[cfg(feature = "f16")]
let bind_group = render_device.create_bind_group(
Some("texture_gaussian_cloud_bind_group"),
&pipeline.gaussian_cloud_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.rotation_scale_opacity).unwrap().texture_view
),
},
],
);
#[cfg(feature = "f32")]
let bind_group = render_device.create_bind_group(
Some("texture_gaussian_cloud_bind_group"),
&pipeline.gaussian_cloud_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.rotation).unwrap().texture_view
),
},
BindGroupEntry {
binding: 3,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.scale_opacity).unwrap().texture_view
),
},
],
);
commands.entity(entity).insert(GpuTextureBuffers { bind_group });
}
}
// TODO: support asset change detection and reupload
fn queue_textures(
mut commands: Commands,
asset_server: Res<AssetServer>,
gaussian_cloud_res: Res<Assets<GaussianCloud>>,
mut images: ResMut<Assets<Image>>,
clouds: Query<(
Entity,
&Handle<GaussianCloud>,
Without<TextureBuffers>,
)>,
) {
for (entity, cloud_handle, _) in clouds.iter() {
if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) {
continue;
}
let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
continue;
};
let square = cloud.len_sqrt_ceil() as u32;
let extent_1d = Extent3d {
width: square,
height: square, // TODO: shrink height to save memory (consider fixed width)
depth_or_array_layers: 1,
};
let mut position_visibility = Image::new(
extent_1d,
TextureDimension::D2,
bytemuck::cast_slice(cloud.position_visibility.as_slice()).to_vec(),
TextureFormat::Rgba32Float,
);
position_visibility.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let position_visibility = images.add(position_visibility);
let texture_buffers: TextureBuffers;
#[cfg(feature = "f16")]
{
let planar_spherical_harmonics: Vec<u32> = (0..SH_VEC4_PLANES)
.flat_map(|plane_index| {
cloud.spherical_harmonic.iter()
.flat_map(move |sh| {
let start_index = plane_index * 4;
let end_index = std::cmp::min(start_index + 4, sh.coefficients.len());
let mut depthwise = sh.coefficients[start_index..end_index].to_vec();
depthwise.resize(4, 0);
depthwise
})
})
.collect();
let mut spherical_harmonics = Image::new(
Extent3d {
width: square,
height: square,
depth_or_array_layers: SH_VEC4_PLANES as u32,
},
TextureDimension::D2,
bytemuck::cast_slice(planar_spherical_harmonics.as_slice()).to_vec(),
TextureFormat::Rgba32Uint,
);
spherical_harmonics.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let spherical_harmonics = images.add(spherical_harmonics);
let mut rotation_scale_opacity = Image::new(
extent_1d,
TextureDimension::D2,
bytemuck::cast_slice(cloud.rotation_scale_opacity_packed128.as_slice()).to_vec(),
TextureFormat::Rgba32Uint,
);
rotation_scale_opacity.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let rotation_scale_opacity = images.add(rotation_scale_opacity);
texture_buffers = TextureBuffers {
position_visibility,
spherical_harmonics,
rotation_scale_opacity,
};
}
#[cfg(feature = "f32")]
{
texture_buffers = TextureBuffers {
position_visibility,
spherical_harmonics: todo!(),
rotation: todo!(),
scale_opacity: todo!(),
};
}
commands.entity(entity).insert(texture_buffers);
}
}
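The plane-major packing of spherical harmonic coefficients in the `f16` branch of `queue_textures` can be shown on plain vectors. This is a sketch under assumptions: `to_planar` is a hypothetical helper, coefficients are modelled as `Vec<u32>` per gaussian, and the start index is clamped so a short coefficient list yields a zero texel instead of panicking (the original slices without that clamp).

```rust
/// Reorders per-gaussian coefficient lists into plane-major texel order:
/// all gaussians' texel 0 (coefficients 0..4), then all texel 1 (4..8), and so on.
/// Each texel is padded with zeros to exactly four u32 values.
pub fn to_planar(coefficients: &[Vec<u32>], planes: usize) -> Vec<u32> {
    (0..planes)
        .flat_map(|plane| {
            coefficients.iter().flat_map(move |sh| {
                // clamp both ends so out-of-range planes produce an empty slice
                let start = (plane * 4).min(sh.len());
                let end = (start + 4).min(sh.len());
                let mut texel = sh[start..end].to_vec();
                texel.resize(4, 0);
                texel
            })
        })
        .collect()
}
```

For two gaussians with six coefficients each and two planes, the output holds both gaussians' first four coefficients, followed by both gaussians' remaining two coefficients, each padded with two zeros. That matches a texture array in which layer `n` stores texel `n` of every gaussian.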
pub fn get_sorted_bind_group_layout(
render_device: &RenderDevice,
) -> BindGroupLayout {
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_sorted_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
],
})
}
#[cfg(feature = "f16")]
pub fn get_bind_group_layout(
render_device: &RenderDevice,
_read_only: bool
) -> BindGroupLayout {
let sh_view_dimension = if SH_VEC4_PLANES == 1 {
TextureViewDimension::D2
} else {
TextureViewDimension::D2Array
};
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_f16_gaussian_cloud_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Float {
filterable: false,
},
multisampled: false,
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: sh_view_dimension,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
],
})
}
#[cfg(feature = "f32")]
pub fn get_bind_group_layout(
render_device: &RenderDevice,
read_only: bool
) -> BindGroupLayout {
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_f32_gaussian_cloud_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<PositionVisibility>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<SphericalHarmonicCoefficients>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<Rotation>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<ScaleOpacity>() as u64),
},
count: None,
},
],
})
}
gaussian_cloud: Self::ExtractedAsset,
render_device: &mut SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let gaussian_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("gaussian cloud buffer"),
contents: bytemuck::cast_slice(gaussian_cloud.0.as_slice()),
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});
let count = gaussian_cloud.0.len() as u32;
// TODO: derive sorting_buffer_size from cloud count (with possible rounding to next power of 2)
let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("sorting global buffer"),
size: ShaderDefines::default().sorting_buffer_size as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("draw indirect buffer"),
size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let sorting_pass_buffers = (0..4)
.map(|idx| {
render_device.create_buffer_with_data(&BufferInitDescriptor {
label: format!("sorting pass buffer {}", idx).as_str().into(),
contents: &[idx as u8, 0, 0, 0],
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
})
})
.collect::<Vec<Buffer>>()
.try_into()
.unwrap();
let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer a"),
size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer b"),
size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
Ok(GpuGaussianCloud {
gaussian_buffer,
count,
draw_indirect_buffer,
sorting_global_buffer,
sorting_pass_buffers,
entry_buffer_a,
entry_buffer_b,
})
}
}
bevy_gaussian_splatting/src/sort/mod.rs
Line 134 in 5c9a20a
use bevy::{
prelude::*,
asset::LoadState,
ecs::system::{
lifetimeless::SRes,
SystemParamItem,
},
reflect::TypeUuid,
render::{
render_asset::{
RenderAsset,
RenderAssetPlugin,
PrepareAssetError,
},
render_resource::*,
renderer::RenderDevice,
},
};
use bytemuck::{
Pod,
Zeroable,
};
use static_assertions::assert_cfg;
use crate::{
GaussianCloud,
GaussianCloudSettings,
};
#[cfg(feature = "sort_radix")]
pub mod radix;
#[cfg(feature = "sort_rayon")]
pub mod rayon;
assert_cfg!(
any(
feature = "sort_radix",
feature = "sort_rayon",
),
"no sort mode enabled",
);
#[derive(
Component,
Debug,
Clone,
PartialEq,
Reflect,
)]
pub enum SortMode {
None,
#[cfg(feature = "sort_radix")]
Radix,
#[cfg(feature = "sort_rayon")]
Rayon,
}
impl Default for SortMode {
#[allow(unreachable_code)]
fn default() -> Self {
#[cfg(feature = "sort_rayon")]
return Self::Rayon;
#[cfg(feature = "sort_radix")]
return Self::Radix;
Self::None
}
}
#[derive(Default)]
pub struct SortPlugin;
impl Plugin for SortPlugin {
fn build(&self, app: &mut App) {
#[cfg(feature = "sort_radix")]
app.add_plugins(radix::RadixSortPlugin);
#[cfg(feature = "sort_rayon")]
app.add_plugins(rayon::RayonSortPlugin);
app.register_type::<SortedEntries>();
app.init_asset::<SortedEntries>();
app.register_asset_reflect::<SortedEntries>();
app.add_plugins(RenderAssetPlugin::<SortedEntries>::default());
app.add_systems(Update, auto_insert_sorted_entries);
}
}
fn auto_insert_sorted_entries(
mut commands: Commands,
asset_server: Res<AssetServer>,
gaussian_clouds_res: Res<Assets<GaussianCloud>>,
mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
gaussian_clouds: Query<(
Entity,
&Handle<GaussianCloud>,
&GaussianCloudSettings,
Without<Handle<SortedEntries>>,
)>,
) {
for (
entity,
gaussian_cloud_handle,
_settings,
_,
) in gaussian_clouds.iter() {
// // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection)
// if settings.sort_mode == SortMode::None {
// continue;
// }
if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
continue;
}
// Skip entities whose cloud asset is not available yet.
let Some(cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) else {
continue;
};
// TODO: move gaussian_cloud and sorted_entry assets into an asset bundle
let sorted_entries = sorted_entries_res.add(SortedEntries {
sorted: (0..cloud.gaussians.len())
.map(|idx| {
SortEntry {
key: 1,
index: idx as u32,
}
})
.collect(),
});
commands.entity(entity)
.insert(sorted_entries);
}
}
#[derive(
Clone,
Copy,
Debug,
Default,
PartialEq,
Reflect,
ShaderType,
Pod,
Zeroable,
)]
#[repr(C)]
pub struct SortEntry {
pub key: u32,
pub index: u32,
}
// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None
// supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer
#[derive(
Clone,
Asset,
Debug,
Default,
PartialEq,
Reflect,
TypeUuid,
)]
#[uuid = "ac2f08eb-fa13-ccdd-ea11-51571ea332d5"]
pub struct SortedEntries {
pub sorted: Vec<SortEntry>,
}
impl RenderAsset for SortedEntries {
type ExtractedAsset = SortedEntries;
type PreparedAsset = GpuSortedEntry;
type Param = SRes<RenderDevice>;
fn extract_asset(&self) -> Self::ExtractedAsset {
self.clone()
}
fn prepare_asset(
sorted_entries: Self::ExtractedAsset,
render_device: &mut SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let sorted_entry_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("sorted_entry_buffer"),
contents: bytemuck::cast_slice(sorted_entries.sorted.as_slice()),
usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});
let count = sorted_entries.sorted.len();
Ok(GpuSortedEntry {
sorted_entry_buffer,
count,
})
}
}
// TODO: support instancing and multiple cameras
// separate entry_buffer_a binding into a unique bind group to optimize buffer updates
#[derive(Debug, Clone)]
pub struct GpuSortedEntry {
pub sorted_entry_buffer: Buffer,
pub count: usize,
}
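The TODO above mentions pre-sorting or CPU sorting in the main world. A minimal, dependency-free sketch of that idea follows. It assumes the key can be the bit pattern of the camera distance and that ascending order is wanted (the radix test in this repository checks for non-decreasing depth). `sort_entries_by_depth` is a hypothetical helper, not part of the crate, and the local `SortEntry` only mirrors the field layout shown above.

```rust
/// Local stand-in that mirrors the `SortEntry` layout shown above
/// (a 32-bit key followed by the gaussian index).
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C)]
pub struct SortEntry {
    pub key: u32,
    pub index: u32,
}

/// Hypothetical helper: builds one entry per position, keyed by the
/// distance to `camera`, and sorts the entries by ascending key.
pub fn sort_entries_by_depth(positions: &[[f32; 3]], camera: [f32; 3]) -> Vec<SortEntry> {
    let mut entries: Vec<SortEntry> = positions
        .iter()
        .enumerate()
        .map(|(idx, p)| {
            let d = [p[0] - camera[0], p[1] - camera[1], p[2] - camera[2]];
            let depth = (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt();
            SortEntry {
                // Non-negative finite f32 values compare the same way as
                // their raw bit patterns, so the bits can serve as a key.
                key: depth.to_bits(),
                index: idx as u32,
            }
        })
        .collect();
    // Ascending by key, i.e. nearest gaussian first.
    entries.sort_unstable_by_key(|e| e.key);
    entries
}
```

The resulting `Vec<SortEntry>` has the same shape as `SortedEntries::sorted`, so under these assumptions it could seed the asset instead of the placeholder `key: 1` entries.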
if settings.sort_mode == SortMode::None {
continue;
}
bevy_gaussian_splatting/src/sort/mod.rs
Line 119 in 5c9a20a
bevy_gaussian_splatting/src/render/texture.rs
Line 222 in 507a28d
use bevy::{
prelude::*,
asset::LoadState,
ecs::query::QueryItem,
render::{
extract_component::{
ExtractComponent,
ExtractComponentPlugin,
},
Render,
RenderApp,
RenderSet,
render_asset::RenderAssets,
render_resource::{
BindGroup,
BindGroupLayout,
BindGroupLayoutDescriptor,
BindGroupLayoutEntry,
BindGroupEntry,
BindingType,
BindingResource,
Extent3d,
TextureDimension,
TextureFormat,
TextureSampleType,
TextureUsages,
TextureViewDimension,
ShaderStages,
},
renderer::RenderDevice,
},
};
use static_assertions::assert_cfg;
#[allow(unused_imports)]
use crate::{
gaussian::{
cloud::GaussianCloud,
f32::{
PositionVisibility,
Rotation,
ScaleOpacity,
},
},
material::spherical_harmonics::{
SH_COEFF_COUNT,
SH_VEC4_PLANES,
SphericalHarmonicCoefficients,
},
render::{
GaussianCloudPipeline,
GpuGaussianCloud,
},
};
// TODO: support loading from directory of images
assert_cfg!(
feature = "planar",
"texture rendering is only supported with the `planar` feature enabled",
);
assert_cfg!(
not(feature = "f32"),
"f32 texture support is not implemented yet",
);
#[derive(Component, Clone, Debug, Reflect)]
pub struct TextureBuffers {
position_visibility: Handle<Image>,
spherical_harmonics: Handle<Image>,
#[cfg(feature = "f16")]
rotation_scale_opacity: Handle<Image>,
#[cfg(feature = "f32")]
rotation: Handle<Image>,
#[cfg(feature = "f32")]
scale_opacity: Handle<Image>,
}
impl ExtractComponent for TextureBuffers {
type Query = &'static Self;
type Filter = ();
type Out = Self;
fn extract_component(texture_buffers: QueryItem<'_, Self::Query>) -> Option<Self::Out> {
texture_buffers.clone().into()
}
}
#[derive(Default)]
pub struct BufferTexturePlugin;
impl Plugin for BufferTexturePlugin {
fn build(&self, app: &mut App) {
app.register_type::<TextureBuffers>();
app.add_plugins(ExtractComponentPlugin::<TextureBuffers>::default());
app.add_systems(Update, queue_textures);
let render_app = app.sub_app_mut(RenderApp);
render_app.add_systems(
Render,
queue_gpu_texture_buffers.in_set(RenderSet::PrepareAssets),
);
}
}
#[derive(Component, Clone, Debug)]
pub struct GpuTextureBuffers {
pub bind_group: BindGroup,
}
pub fn queue_gpu_texture_buffers(
mut commands: Commands,
// gaussian_cloud_pipeline: Res<GaussianCloudPipeline>,
pipeline: ResMut<GaussianCloudPipeline>,
render_device: ResMut<RenderDevice>,
gpu_images: Res<RenderAssets<Image>>,
clouds: Query<(
Entity,
&TextureBuffers,
)>,
) {
for (entity, texture_buffers,) in clouds.iter() {
#[cfg(feature = "f16")]
let bind_group = render_device.create_bind_group(
Some("texture_gaussian_cloud_bind_group"),
&pipeline.gaussian_cloud_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.rotation_scale_opacity).unwrap().texture_view
),
},
],
);
#[cfg(feature = "f32")]
let bind_group = render_device.create_bind_group(
Some("texture_gaussian_cloud_bind_group"),
&pipeline.gaussian_cloud_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.rotation).unwrap().texture_view
),
},
BindGroupEntry {
binding: 3,
resource: BindingResource::TextureView(
&gpu_images.get(&texture_buffers.scale_opacity).unwrap().texture_view
),
},
],
);
commands.entity(entity).insert(GpuTextureBuffers { bind_group });
}
}
// TODO: support asset change detection and reupload
fn queue_textures(
mut commands: Commands,
asset_server: Res<AssetServer>,
gaussian_cloud_res: Res<Assets<GaussianCloud>>,
mut images: ResMut<Assets<Image>>,
clouds: Query<(
Entity,
&Handle<GaussianCloud>,
Without<TextureBuffers>,
)>,
) {
for (entity, cloud_handle, _) in clouds.iter() {
if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle){
continue;
}
// Look the asset up once and skip the entity until it is available.
let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
continue;
};
let square = cloud.len_sqrt_ceil() as u32;
let extent_1d = Extent3d {
width: square,
height: square, // TODO: shrink height to save memory (consider fixed width)
depth_or_array_layers: 1,
};
let mut position_visibility = Image::new(
extent_1d,
TextureDimension::D2,
bytemuck::cast_slice(cloud.position_visibility.as_slice()).to_vec(),
TextureFormat::Rgba32Float,
);
position_visibility.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let position_visibility = images.add(position_visibility);
let texture_buffers: TextureBuffers;
#[cfg(feature = "f16")]
{
let planar_spherical_harmonics: Vec<u32> = (0..SH_VEC4_PLANES)
.flat_map(|plane_index| {
cloud.spherical_harmonic.iter()
.flat_map(move |sh| {
let start_index = plane_index * 4;
let end_index = std::cmp::min(start_index + 4, sh.coefficients.len());
let mut depthwise = sh.coefficients[start_index..end_index].to_vec();
depthwise.resize(4, 0);
depthwise
})
})
.collect();
let mut spherical_harmonics = Image::new(
Extent3d {
width: square,
height: square,
depth_or_array_layers: SH_VEC4_PLANES as u32,
},
TextureDimension::D2,
bytemuck::cast_slice(planar_spherical_harmonics.as_slice()).to_vec(),
TextureFormat::Rgba32Uint,
);
spherical_harmonics.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let spherical_harmonics = images.add(spherical_harmonics);
let mut rotation_scale_opacity = Image::new(
extent_1d,
TextureDimension::D2,
bytemuck::cast_slice(cloud.rotation_scale_opacity_packed128.as_slice()).to_vec(),
TextureFormat::Rgba32Uint,
);
rotation_scale_opacity.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
let rotation_scale_opacity = images.add(rotation_scale_opacity);
texture_buffers = TextureBuffers {
position_visibility,
spherical_harmonics,
rotation_scale_opacity,
};
}
#[cfg(feature = "f32")]
{
texture_buffers = TextureBuffers {
position_visibility,
spherical_harmonics: todo!(),
rotation: todo!(),
scale_opacity: todo!(),
};
}
commands.entity(entity).insert(texture_buffers);
}
}
pub fn get_sorted_bind_group_layout(
render_device: &RenderDevice,
) -> BindGroupLayout {
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_sorted_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
],
})
}
#[cfg(feature = "f16")]
pub fn get_bind_group_layout(
render_device: &RenderDevice,
_read_only: bool
) -> BindGroupLayout {
let sh_view_dimension = if SH_VEC4_PLANES == 1 {
TextureViewDimension::D2
} else {
TextureViewDimension::D2Array
};
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_f16_gaussian_cloud_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Float {
filterable: false,
},
multisampled: false,
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: sh_view_dimension,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::all(),
ty: BindingType::Texture {
view_dimension: TextureViewDimension::D2,
sample_type: TextureSampleType::Uint,
multisampled: false,
},
count: None,
},
],
})
}
#[cfg(feature = "f32")]
pub fn get_bind_group_layout(
render_device: &RenderDevice,
read_only: bool
) -> BindGroupLayout {
render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("texture_f32_gaussian_cloud_layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<PositionVisibility>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<SphericalHarmonicCoefficients>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<Rotation>() as u64),
},
count: None,
},
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::all(),
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(std::mem::size_of::<ScaleOpacity>() as u64),
},
count: None,
},
],
})
}
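`queue_textures` above stores per-gaussian data in a square 2D texture whose side comes from `cloud.len_sqrt_ceil()`. The sketch below shows one way such a layout can be addressed. It assumes `len_sqrt_ceil` returns the ceiling of the square root of the gaussian count and that texels are filled in row-major order; neither is confirmed by the snippet. `square_side`, `texel_coords`, and `padded_texel_count` are hypothetical names.

```rust
/// Smallest square side length whose area can hold `count` texels.
/// Assumed to correspond to what `len_sqrt_ceil` computes.
pub fn square_side(count: usize) -> u32 {
    // Start from the truncated floating-point root, then correct upward.
    let mut side = (count as f64).sqrt() as u32;
    while (side as usize) * (side as usize) < count {
        side += 1;
    }
    side
}

/// Maps a linear gaussian index to (x, y) texel coordinates,
/// assuming row-major order. `side` must be non-zero.
pub fn texel_coords(index: usize, side: u32) -> (u32, u32) {
    let side = side as usize;
    ((index % side) as u32, (index / side) as u32)
}

/// Number of texels in the square texture, including unused padding
/// texels after the last gaussian.
pub fn padded_texel_count(count: usize) -> usize {
    let side = square_side(count) as usize;
    side * side
}
```

For example, under these assumptions a cloud of 10,001 gaussians needs a 101 x 101 texture, which leaves 200 padding texels. That overhead is what the "shrink height to save memory" TODO refers to.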
&mut RenderPhase<Transparent3d>,
)>,
) {
// TODO: condition this system based on GaussianCloudBindGroup attachment
if gaussian_cloud_uniform.buffer().is_none() {
return;
};
let draw_custom = transparent_3d_draw_functions.read().id::<DrawGaussians>();
for (_view, mut transparent_phase) in &mut views {
Create a pipeline to automatically optimize .ply to .gcloud asset format in release: https://bevyengine.org/news/bevy-0-12/#bevy-asset-v2
bevy_gaussian_splatting/src/render/mod.rs
Line 360 in 4356f87
],
});
GaussianCloudPipeline {
gaussian_cloud_layout,
gaussian_uniform_layout,
view_layout,
shader: GAUSSIAN_SHADER_HANDLE,
sorted_layout,
}
}
}
// TODO: allow setting shader defines via API
// TODO: separate shader defines for each pipeline
struct ShaderDefines {
radix_bits_per_digit: u32,
radix_digit_places: u32,
impl Plugin for GaussianSplattingPlugin {
fn build(&self, app: &mut App) {
// TODO: allow hot reloading of GaussianCloud handle through inspector UI
app.add_asset::<GaussianCloud>();
app.init_asset_loader::<GaussianCloudLoader>();
app.register_asset_reflect::<GaussianCloud>();
app.register_type::<GaussianCloudSettings>();
app.register_type::<GaussianSplattingBundle>();
app.add_plugins((
bevy_gaussian_splatting/src/render/mod.rs
Line 76 in 54e44af
};
// TODO: separate sort and render pipelines into separate files
const BINDINGS_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(675257236);
const GAUSSIAN_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(68294581);
const RADIX_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(6234673214);
const SPHERICAL_HARMONICS_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(834667312);
const TEMPORAL_SORT_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(1634543224);
const TRANSFORM_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(734523534);
pub mod node {
pub const RADIX_SORT: &str = "radix_sort";
bevy_gaussian_splatting/tests/gpu/radix.rs
Line 197 in 5c9a20a
use std::{
process::exit,
sync::{
Arc,
Mutex,
},
};
use bevy::{
prelude::*,
core::FrameCount,
core_pipeline::core_3d::{
CORE_3D,
Transparent3d,
},
render::{
RenderApp,
renderer::{
RenderContext,
RenderQueue,
},
render_asset::RenderAssets,
render_graph::{
Node,
NodeRunError,
RenderGraphApp,
RenderGraphContext,
},
render_phase::RenderPhase,
view::ExtractedView,
},
};
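use bevy_gaussian_splatting::{
GaussianCloud,
GaussianCloudSettings,
GaussianSplattingBundle,
random_gaussians,
// `SortMode` is used in `setup` below and is defined in the `sort` module.
sort::{
SortMode,
SortedEntries,
},
};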
use _harness::{
TestHarness,
test_harness_app,
TestState,
TestStateArc,
};
mod _harness;
pub mod node {
pub const RADIX_SORT_TEST: &str = "radix_sort_test";
}
fn main() {
let mut app = test_harness_app(TestHarness {
resolution: (512.0, 512.0),
});
app.add_systems(Startup, setup);
if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
render_app
.add_render_graph_node::<RadixTestNode>(
CORE_3D,
node::RADIX_SORT_TEST,
)
.add_render_graph_edge(
CORE_3D,
node::RADIX_SORT_TEST,
bevy::core_pipeline::core_3d::graph::node::END_MAIN_PASS,
);
}
app.run();
}
fn setup(
mut commands: Commands,
mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
let cloud = gaussian_assets.add(random_gaussians(10000));
commands.spawn((
GaussianSplattingBundle {
cloud,
settings: GaussianCloudSettings {
sort_mode: SortMode::Radix,
..default()
},
..default()
},
Name::new("gaussian_cloud"),
));
commands.spawn((
Camera3dBundle {
transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
..default()
},
));
}
pub struct RadixTestNode {
gaussian_clouds: QueryState<(
&'static Handle<GaussianCloud>,
&'static Handle<SortedEntries>,
)>,
state: TestStateArc,
views: QueryState<(
&'static ExtractedView,
&'static RenderPhase<Transparent3d>,
)>,
start_frame: u32,
}
impl FromWorld for RadixTestNode {
fn from_world(world: &mut World) -> Self {
Self {
gaussian_clouds: world.query(),
state: Arc::new(Mutex::new(TestState::default())),
views: world.query(),
start_frame: 0,
}
}
}
impl Node for RadixTestNode {
fn update(
&mut self,
world: &mut World,
) {
let mut state = self.state.lock().unwrap();
if state.test_completed {
exit(0);
}
if state.test_loaded && self.start_frame == 0 {
self.start_frame = world.get_resource::<FrameCount>().unwrap().0;
}
let frame_count = world.get_resource::<FrameCount>().unwrap().0;
const FRAME_LIMIT: u32 = 10;
if state.test_loaded && frame_count >= self.start_frame + FRAME_LIMIT {
state.test_completed = true;
}
self.gaussian_clouds.update_archetypes(world);
self.views.update_archetypes(world);
}
fn run(
&self,
_graph: &mut RenderGraphContext,
render_context: &mut RenderContext,
world: &World,
) -> Result<(), NodeRunError> {
for (view, _phase,) in self.views.iter_manual(world) {
let camera_position = view.transform.translation();
for (
cloud_handle,
sorted_entries_handle,
) in self.gaussian_clouds.iter_manual(world) {
let gaussian_cloud_res = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap();
let sorted_entries_res = world.get_resource::<RenderAssets<SortedEntries>>().unwrap();
let mut state = self.state.lock().unwrap();
if gaussian_cloud_res.get(cloud_handle).is_none() || sorted_entries_res.get(sorted_entries_handle).is_none() {
continue;
} else if !state.test_loaded {
state.test_loaded = true;
}
let cloud = gaussian_cloud_res.get(cloud_handle).unwrap();
let sorted_entries = sorted_entries_res.get(sorted_entries_handle).unwrap();
let gaussians = cloud.debug_gpu.gaussians.clone();
wgpu::util::DownloadBuffer::read_buffer(
render_context.render_device().wgpu_device(),
world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
&sorted_entries.sorted_entry_buffer.slice(
0..sorted_entries.sorted_entry_buffer.size()
),
move |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
let binding = buffer.unwrap();
let u32_muck = bytemuck::cast_slice::<u8, u32>(&*binding);
let mut radix_sorted_indices = Vec::new();
for i in (1..u32_muck.len()).step_by(2) {
radix_sorted_indices.push((i, u32_muck[i] as usize));
}
// TODO: depth order validation over ndc cells
radix_sorted_indices.iter()
.fold(0.0, |depth_acc, &(entry_idx, idx)| {
if idx == 0 || u32_muck[entry_idx - 1] == 0xffffffff || u32_muck[entry_idx - 1] == 0x0 {
return depth_acc;
}
let position = gaussians[idx].position;
let position_vec3 = Vec3::new(position[0], position[1], position[2]);
let depth = (position_vec3 - camera_position).length();
let depth_is_non_decreasing = depth_acc <= depth;
if !depth_is_non_decreasing {
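// Neighbouring keys are read with bounds checks so that this diagnostic
// cannot itself panic at the first or last entry.
println!(
"radix keys: [..., {:x?}, {:#010x}, {:x?}, ...]",
(entry_idx - 1).checked_sub(2).and_then(|i| u32_muck.get(i)),
u32_muck[entry_idx - 1],
u32_muck.get(entry_idx + 1),
);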
}
assert!(depth_is_non_decreasing, "radix sort, non-decreasing check failed: {} > {}", depth_acc, depth);
depth_acc.max(depth)
});
}
);
}
}
Ok(())
}
}
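The check performed inside the download callback above can be sketched on the CPU without Bevy or wgpu. The sketch assumes the buffer is a flat `u32` slice of interleaved `(key, index)` pairs, which is what the `#[repr(C)]` layout of `SortEntry` implies. `decode_entries` and `first_depth_violation` are hypothetical helpers, and unlike the test above this version does not skip entries whose index is 0.

```rust
/// Decodes `[key0, index0, key1, index1, ...]` into `(key, index)` pairs.
/// A trailing unpaired word, if any, is ignored.
pub fn decode_entries(words: &[u32]) -> Vec<(u32, usize)> {
    words
        .chunks_exact(2)
        .map(|pair| (pair[0], pair[1] as usize))
        .collect()
}

/// Returns the entry position of the first depth that is smaller than the
/// running maximum, or `None` if depths are non-decreasing.
/// Entries with sentinel keys (0x0 or 0xffffffff) are skipped, and an
/// out-of-range index is reported as a violation.
pub fn first_depth_violation(
    words: &[u32],
    positions: &[[f32; 3]],
    camera: [f32; 3],
) -> Option<usize> {
    let mut max_depth = 0.0f32;
    for (entry_idx, (key, index)) in decode_entries(words).into_iter().enumerate() {
        if key == 0 || key == u32::MAX {
            continue;
        }
        let Some(p) = positions.get(index) else {
            return Some(entry_idx);
        };
        let d = [p[0] - camera[0], p[1] - camera[1], p[2] - camera[2]];
        let depth = (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt();
        if depth < max_depth {
            return Some(entry_idx);
        }
        max_depth = depth;
    }
    None
}
```

Returning the offending position rather than asserting inside the loop keeps the check reusable, for instance to print the neighbouring keys before failing.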
}),
},
],
));
for (entity, cloud_handle) in gaussian_clouds.iter() {
// TODO: add asset loading indicator (and maybe streamed loading)
if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) {
continue;
}
if gaussian_cloud_res.get(cloud_handle).is_none() {
continue;
}
let rangefinder = view.rangefinder3d();
.distance_translation(&mesh_instance.transforms.transform.translation),
let pipeline = pipelines.specialize(&pipeline_cache, &custom_pipeline, key);
// // TODO: distance to gaussian cloud centroid
// let rangefinder = view.rangefinder3d();
transparent_phase.add(Transparent3d {
entity,
draw_function: draw_custom,
distance: 0.0,
// distance: rangefinder
// .distance_translation(&mesh_instance.transforms.transform.translation),
pipeline,
batch_range: 0..1,
dynamic_offset: None,
});
}
}