
mosure / bevy_gaussian_splatting


bevy gaussian splatting render pipeline plugin

Home Page: https://mosure.github.io/bevy_gaussian_splatting/index.html?arg1=cactus.gcloud

License: MIT License

Rust 79.90% Dockerfile 0.19% WGSL 19.71% HTML 0.20%
bevy gaussian-splatting particles render-pipeline rust webgl2 webgpu

bevy_gaussian_splatting's People

Contributors

cs50victor, mosure

Forkers

cs50victor

bevy_gaussian_splatting's Issues

Visual artifact with example scene

Hi. Thank you for creating a Bevy implementation of Gaussian Splatting! 🦀

I have tried running the example using this command: cargo run --release -- scenes/icecream.gcloud

It compiles and runs, but the result looks off (the splats are too large) and some of the gaussians flicker badly, even when the camera is not moving:
image

System information (Dell Precision 7560 notebook with Windows 11):

[Display]
DirectX version:	12.0 
GPU processor:		NVIDIA RTX A3000 Laptop GPU
Driver version:		536.45
Driver Type:		DCH
Direct3D feature level:	12_1
CUDA Cores:		4096 
Resizable BAR		Yes
Dynamic Boost 2.0	Yes
WhisperMode 2.0		No
Advanced Optimus	No
Maximum Graphics Power	105 W
Core clock:		1560 MHz 
Memory data rate:	11.00 Gbps
Memory interface:	192-bit 
Memory bandwidth:	264.05 GB/s
Total available graphics memory:	38504 MB
Dedicated video memory:	6144 MB GDDR6
System video memory:	0 MB
Shared system memory:	32360 MB
Video BIOS version:	94.04.42.40.06
IRQ:			Not used
Bus:			PCI Express x8 Gen4
Device ID:		XXX
Part Number:		XXX

[Components]

nvui.dll		8.17.15.3645		NVIDIA User Experience Driver Component
nvxdplcy.dll		8.17.15.3645		NVIDIA User Experience Driver Component
nvxdbat.dll		8.17.15.3645		NVIDIA User Experience Driver Component
nvxdapix.dll		8.17.15.3645		NVIDIA User Experience Driver Component
NVCPL.DLL		8.17.15.3645		NVIDIA User Experience Driver Component
nvCplUIR.dll		8.1.940.0		NVIDIA Control Panel
nvCplUI.exe		8.1.940.0		NVIDIA Control Panel
nvWSSR.dll		31.0.15.3645		NVIDIA Workstation Server
nvWSS.dll		31.0.15.3645		NVIDIA Workstation Server
nvViTvSR.dll		31.0.15.3645		NVIDIA Video Server
nvViTvS.dll		31.0.15.3645		NVIDIA Video Server
nvLicensingS.dll		6.14.15.3645		NVIDIA Licensing Server
nvDevToolSR.dll		31.0.15.3645		NVIDIA Licensing Server
nvDevToolS.dll		31.0.15.3645		NVIDIA 3D Settings Server
nvDispSR.dll		31.0.15.3645		NVIDIA Display Server
nvDispS.dll		31.0.15.3645		NVIDIA Display Server
PhysX		09.21.0713		NVIDIA PhysX
NVCUDA64.DLL		31.0.15.3645		NVIDIA CUDA 12.2.101 driver
nvGameSR.dll		31.0.15.3645		NVIDIA 3D Settings Server
nvGameS.dll		31.0.15.3645		NVIDIA 3D Settings Server

fix transforms - dynamic offset using DynamicUniformIndex

https://api.github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L1147

            None => return RenderCommandResult::Failure,
        };

        pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]);
        pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);

        pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);

        RenderCommandResult::Success
    }
}





struct RadixSortNode {
    gaussian_clouds: QueryState<(
        &'static Handle<GaussianCloud>,
        &'static GaussianCloudBindGroup
    )>,
    initialized: bool,
    pipeline_idx: Option<u32>,
    view_bind_group: QueryState<(
        &'static GaussianViewBindGroup,
        &'static ViewUniformOffset,
    )>,
}

impl FromWorld for RadixSortNode {
    fn from_world(world: &mut World) -> Self {
        Self {
            gaussian_clouds: world.query(),
            initialized: false,
            pipeline_idx: None,
            view_bind_group: world.query(),
        }
    }
}

impl render_graph::Node for RadixSortNode {
    fn update(&mut self, world: &mut World) {
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();

        if !self.initialized {
            let mut pipelines_loaded = true;
            for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
                if let CachedPipelineState::Ok(_) =
                        pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
                {
                    continue;
                }

                pipelines_loaded = false;
            }

            self.initialized = pipelines_loaded;

            if !self.initialized {
                return;
            }
        }

        if self.pipeline_idx.is_none() {
            self.pipeline_idx = Some(0);
        } else {
            self.pipeline_idx = Some((self.pipeline_idx.unwrap() + 1) % pipeline.radix_sort_pipelines.len() as u32);
        }

        self.gaussian_clouds.update_archetypes(world);
        self.view_bind_group.update_archetypes(world);
    }

    fn run(
        &self,
        _graph: &mut render_graph::RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), render_graph::NodeRunError> {
        if !self.initialized || self.pipeline_idx.is_none() {
            return Ok(());
        }

        let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort

        let pipeline_cache = world.resource::<PipelineCache>();
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();

        let command_encoder = render_context.command_encoder();

        for (
            view_bind_group,
            view_uniform_offset,
        ) in self.view_bind_group.iter_manual(world) {
            for (
                cloud_handle,
                cloud_bind_group
            ) in self.gaussian_clouds.iter_manual(world) {
                let cloud = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap().get(cloud_handle).unwrap();

                let radix_digit_places = ShaderDefines::default().radix_digit_places;

                command_encoder.clear_buffer(
                    &cloud.sorting_global_buffer,
                    0,
                    None,
                );

                {
                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    // TODO: view/global
                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[1],
                        &[],
                    );

                    let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
                    pass.set_pipeline(radix_sort_a);

                    let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
                    pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);


                    let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
                    pass.set_pipeline(radix_sort_b);

                    pass.dispatch_workgroups(1, radix_digit_places, 1);
                }

                for pass_idx in 0..radix_digit_places {
                    if pass_idx > 0 {
                        let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::<u32>() as u32;
                        command_encoder.clear_buffer(
                            &cloud.sorting_global_buffer,
                            0,
                            std::num::NonZeroU64::new(size as u64).unwrap().into()
                        );
                    }

                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
                    pass.set_pipeline(&radix_sort_c);

                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[pass_idx as usize],
                        &[],
                    );

                    let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
                    pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
                }
            }
        }


        Ok(())
    }
}
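
The dispatch sizes in the excerpt above use the round-up idiom `(count + entries - 1) / entries`, so a partially filled final workgroup is still launched. A minimal standalone sketch of that arithmetic follows; the function name `workgroup_count` and the sample sizes are illustrative and are not taken from the crate.

```rust
/// Number of workgroups needed to cover `count` entries when each
/// workgroup handles `entries_per_workgroup` entries (rounds up).
/// Hypothetical helper; the excerpt inlines the expression instead.
fn workgroup_count(count: u32, entries_per_workgroup: u32) -> u32 {
    assert!(entries_per_workgroup > 0, "workgroup size must be non-zero");
    // Equivalent to (count + n - 1) / n, but cannot overflow near u32::MAX.
    count / entries_per_workgroup + u32::from(count % entries_per_workgroup != 0)
}
```

For example, with 1024 entries per workgroup, 1024 gaussians need one workgroup and 1025 need two.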

update DrawIndirect buffer during sort phase (GPU sort will override default Dra...

// TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)

use bevy::{
    prelude::*,
    asset::LoadState,
    utils::Instant,
};

use crate::{
    GaussianCloud,
    GaussianCloudSettings,
    sort::{
        SortedEntries,
        SortMode,
    },
};


#[derive(Default)]
pub struct StdSortPlugin;

impl Plugin for StdSortPlugin {
    fn build(&self, app: &mut App) {
        app.add_systems(Update, std_sort);
    }
}

pub fn std_sort(
    asset_server: Res<AssetServer>,
    gaussian_clouds_res: Res<Assets<GaussianCloud>>,
    mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
    gaussian_clouds: Query<(
        &Handle<GaussianCloud>,
        &Handle<SortedEntries>,
        &GaussianCloudSettings,
    )>,
    cameras: Query<(
        &GlobalTransform,
        &Camera3d,
    )>,
    mut last_camera_position: Local<Vec3>,
    mut last_sort_time: Local<Option<Instant>>,
) {
    let period = std::time::Duration::from_millis(100);
    if let Some(last_sort_time) = last_sort_time.as_ref() {
        if last_sort_time.elapsed() < period {
            return;
        }
    }

    for (
        camera_transform,
        _camera,
    ) in cameras.iter() {
        let camera_position = camera_transform.compute_transform().translation;
        if *last_camera_position == camera_position {
            return;
        }

        for (
            gaussian_cloud_handle,
            sorted_entries_handle,
            settings,
        ) in gaussian_clouds.iter() {
            if settings.sort_mode != SortMode::Std {
                continue;
            }

            if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
                continue;
            }

            if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) {
                continue;
            }

            if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) {
                if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) {
                    assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len());

                    *last_camera_position = camera_position;
                    *last_sort_time = Some(Instant::now());

                    gaussian_cloud.gaussians.iter()
                        .zip(sorted_entries.sorted.iter_mut())
                        .enumerate()
                        .for_each(|(idx, (gaussian, sort_entry))| {
                            let position = Vec3::from_slice(gaussian.position.as_ref());
                            let delta = camera_position - position;

                            sort_entry.key = bytemuck::cast(delta.length_squared());
                            sort_entry.index = idx as u32;
                        });

                    sorted_entries.sorted.sort_unstable_by(|a, b| {
                        bytemuck::cast::<u32, f32>(b.key).partial_cmp(&bytemuck::cast::<u32, f32>(a.key)).unwrap()
                    });

                    // TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)
                }
            }
        }
    }
}
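
The system above stores each squared camera distance as the raw bits of an `f32` inside a `u32` key and then orders the entries farthest-first. The same idea can be sketched with the standard library alone. `SortEntry` and `sort_back_to_front` below are hypothetical stand-ins, not the crate's types, and `total_cmp` is swapped in for `partial_cmp(..).unwrap()` so that a NaN key cannot panic.

```rust
/// Hypothetical stand-in for the crate's sort entry: the key holds the raw
/// bits of an f32 squared distance, the index identifies the gaussian.
#[derive(Clone, Copy, Debug, PartialEq)]
struct SortEntry {
    key: u32,
    index: u32,
}

/// Compute squared distances from `camera` to each position and return the
/// entries ordered back-to-front (farthest first).
fn sort_back_to_front(positions: &[[f32; 3]], camera: [f32; 3]) -> Vec<SortEntry> {
    let mut entries: Vec<SortEntry> = positions
        .iter()
        .enumerate()
        .map(|(idx, p)| {
            let d = [camera[0] - p[0], camera[1] - p[1], camera[2] - p[2]];
            let dist_sq = d[0] * d[0] + d[1] * d[1] + d[2] * d[2];
            SortEntry { key: dist_sq.to_bits(), index: idx as u32 }
        })
        .collect();

    // total_cmp defines a total order over all f32 values, NaN included.
    entries.sort_unstable_by(|a, b| {
        f32::from_bits(b.key).total_cmp(&f32::from_bits(a.key))
    });
    entries
}
```

Because squared distances are never negative, their bit patterns sort in the same order as the values themselves, which is what lets a GPU radix sort operate on the `u32` keys directly.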

view/global

https://api.github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L1138

            None => return RenderCommandResult::Failure,
        };

        pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]);
        pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);

        pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);

        RenderCommandResult::Success
    }
}





struct RadixSortNode {
    gaussian_clouds: QueryState<(
        &'static Handle<GaussianCloud>,
        &'static GaussianCloudBindGroup
    )>,
    initialized: bool,
    pipeline_idx: Option<u32>,
    view_bind_group: QueryState<(
        &'static GaussianViewBindGroup,
        &'static ViewUniformOffset,
    )>,
}

impl FromWorld for RadixSortNode {
    fn from_world(world: &mut World) -> Self {
        Self {
            gaussian_clouds: world.query(),
            initialized: false,
            pipeline_idx: None,
            view_bind_group: world.query(),
        }
    }
}

impl render_graph::Node for RadixSortNode {
    fn update(&mut self, world: &mut World) {
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();

        if !self.initialized {
            let mut pipelines_loaded = true;
            for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
                if let CachedPipelineState::Ok(_) =
                        pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
                {
                    continue;
                }

                pipelines_loaded = false;
            }

            self.initialized = pipelines_loaded;

            if !self.initialized {
                return;
            }
        }

        if self.pipeline_idx.is_none() {
            self.pipeline_idx = Some(0);
        } else {
            self.pipeline_idx = Some((self.pipeline_idx.unwrap() + 1) % pipeline.radix_sort_pipelines.len() as u32);
        }

        self.gaussian_clouds.update_archetypes(world);
        self.view_bind_group.update_archetypes(world);
    }

    fn run(
        &self,
        _graph: &mut render_graph::RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), render_graph::NodeRunError> {
        if !self.initialized || self.pipeline_idx.is_none() {
            return Ok(());
        }

        let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort

        let pipeline_cache = world.resource::<PipelineCache>();
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();

        let command_encoder = render_context.command_encoder();

        for (
            view_bind_group,
            view_uniform_offset,
        ) in self.view_bind_group.iter_manual(world) {
            for (
                cloud_handle,
                cloud_bind_group
            ) in self.gaussian_clouds.iter_manual(world) {
                let cloud = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap().get(cloud_handle).unwrap();

                let radix_digit_places = ShaderDefines::default().radix_digit_places;

                command_encoder.clear_buffer(
                    &cloud.sorting_global_buffer,
                    0,
                    None,
                );

                {
                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    // TODO: view/global
                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[1],
                        &[],
                    );

                    let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
                    pass.set_pipeline(radix_sort_a);

                    let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
                    pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);


                    let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
                    pass.set_pipeline(radix_sort_b);

                    pass.dispatch_workgroups(1, radix_digit_places, 1);
                }

                for pass_idx in 0..radix_digit_places {
                    if pass_idx > 0 {
                        let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::<u32>() as u32;
                        command_encoder.clear_buffer(
                            &cloud.sorting_global_buffer,
                            0,
                            std::num::NonZeroU64::new(size as u64).unwrap().into()
                        );
                    }

                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
                    pass.set_pipeline(&radix_sort_c);

                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[pass_idx as usize],
                        &[],
                    );

                    let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
                    pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
                }
            }
        }


        Ok(())
    }
}

keep draw_indirect at the gaussian cloud level

// TODO: keep draw_indirect at the gaussian cloud level

    ) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
        let gaussian_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: Some("gaussian cloud buffer"),
            contents: bytemuck::cast_slice(gaussian_cloud.gaussians.as_slice()),
            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
        });

        let count = gaussian_cloud.gaussians.len();

        // TODO: keep draw_indirect at the gaussian cloud level
        let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
            label: Some("draw indirect buffer"),
            size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
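
For context on what such a buffer holds: a non-indexed indirect draw reads four consecutive `u32` values (vertex count, instance count, first vertex, first instance), 16 bytes in total. The sketch below mirrors that layout with a hypothetical `DrawIndirectArgs` type. Drawing each gaussian as one instance of a 4-vertex quad is an assumption made for illustration, not something the excerpt confirms.

```rust
/// Hypothetical mirror of the four-u32 argument block consumed by a
/// non-indexed indirect draw (16 bytes, tightly packed).
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct DrawIndirectArgs {
    vertex_count: u32,
    instance_count: u32,
    first_vertex: u32,
    first_instance: u32,
}

impl DrawIndirectArgs {
    /// Assumption: one instance per gaussian, each drawn as a 4-vertex quad.
    fn for_cloud(gaussian_count: u32) -> Self {
        Self {
            vertex_count: 4,
            instance_count: gaussian_count,
            first_vertex: 0,
            first_instance: 0,
        }
    }

    /// Little-endian bytes in field order, ready to copy into a GPU buffer.
    fn to_bytes(self) -> [u8; 16] {
        let fields = [
            self.vertex_count,
            self.instance_count,
            self.first_vertex,
            self.first_instance,
        ];
        let mut out = [0u8; 16];
        for (chunk, value) in out.chunks_exact_mut(4).zip(fields) {
            chunk.copy_from_slice(&value.to_le_bytes());
        }
        out
    }
}
```

Keeping one such block per cloud would let a sort pass rewrite `instance_count` on the GPU without touching any other cloud's draw arguments.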

convert higher degree SH to lower degree SH

// TODO: convert higher degree SH to lower degree SH

    fn set_property(&mut self, key: String, property: Property) {
        match (key.as_ref(), property) {
            ("x", Property::Float(v))           => self.position_visibility.position[0] = v,
            ("y", Property::Float(v))           => self.position_visibility.position[1] = v,
            ("z", Property::Float(v))           => self.position_visibility.position[2] = v,
            ("f_dc_0", Property::Float(v))      => self.spherical_harmonic.set(0, v),
            ("f_dc_1", Property::Float(v))      => self.spherical_harmonic.set(1, v),
            ("f_dc_2", Property::Float(v))      => self.spherical_harmonic.set(2, v),
            ("scale_0", Property::Float(v))     => self.scale_opacity.scale[0] = v,
            ("scale_1", Property::Float(v))     => self.scale_opacity.scale[1] = v,
            ("scale_2", Property::Float(v))     => self.scale_opacity.scale[2] = v,
            ("opacity", Property::Float(v))     => self.scale_opacity.opacity = 1.0 / (1.0 + (-v).exp()),
            ("rot_0", Property::Float(v))       => self.rotation.rotation[0] = v,
            ("rot_1", Property::Float(v))       => self.rotation.rotation[1] = v,
            ("rot_2", Property::Float(v))       => self.rotation.rotation[2] = v,
            ("rot_3", Property::Float(v))       => self.rotation.rotation[3] = v,
            (_, Property::Float(v)) if key.starts_with("f_rest_") => {
                let i = key[7..].parse::<usize>().unwrap();

                // interleaved
                // if (i + 3) < SH_COEFF_COUNT {
                //     self.spherical_harmonic.coefficients[i + 3] = v;
                // }

                // planar
                let channel = i / SH_COEFF_COUNT_PER_CHANNEL;
                let coefficient = if SH_COEFF_COUNT_PER_CHANNEL == 1 {
                    1
                } else {
                    (i % (SH_COEFF_COUNT_PER_CHANNEL - 1)) + 1
                };

                let interleaved_idx = coefficient * SH_CHANNELS + channel;

                if interleaved_idx < SH_COEFF_COUNT {
                    self.spherical_harmonic.set(interleaved_idx, v);
                } else {
                    // TODO: convert higher degree SH to lower degree SH
                }
            }
            (_, _) => {},
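
The planar branch above converts the index of an `f_rest_*` property into an interleaved coefficient slot. A standalone sketch of one such mapping is given below. It assumes the file stores all remaining coefficients of channel 0 first, then channel 1, then channel 2, and it takes the file's per-channel count as a parameter, which differs from the excerpt. The constants and the function name are illustrative and are not read from the crate.

```rust
/// Illustrative constants (assumptions): 3 colour channels and degree-1
/// harmonics in memory, i.e. 4 coefficients per channel including DC.
const SH_CHANNELS: usize = 3;
const SH_COEFF_COUNT_PER_CHANNEL: usize = 4;

/// Map the index `i` of an `f_rest_i` property, assumed planar with
/// `rest_per_channel` values per channel in the file, to an interleaved
/// slot (one r, g, b triple per coefficient).
/// Returns None when the coefficient is of a higher degree than is stored.
fn planar_to_interleaved(i: usize, rest_per_channel: usize) -> Option<usize> {
    if rest_per_channel == 0 {
        return None;
    }
    let channel = i / rest_per_channel;
    if channel >= SH_CHANNELS {
        return None;
    }
    // +1 skips the DC term, which arrives separately as f_dc_0..2.
    let coefficient = i % rest_per_channel + 1;
    if coefficient >= SH_COEFF_COUNT_PER_CHANNEL {
        return None;
    }
    Some(coefficient * SH_CHANNELS + channel)
}
```

With a degree-3 file (15 remaining coefficients per channel), `f_rest_0` lands in slot 3 and `f_rest_15` in slot 4, while `f_rest_3` is dropped because it belongs to a higher degree than this layout keeps.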

consider switching to fast (or support multiple options), default is a bit slow

https://api.github.com/mosure/bevy_gaussian_splatting/blob/50685245da00f48ce0d49fe9b05fab65cd2fb149/tools/ply_to_gcloud.rs#L35

use bincode2::serialize_into;
use flate2::{
    Compression,
    write::GzEncoder,
};

use bevy_gaussian_splatting::{
    GaussianCloud,
    ply::parse_ply,
};


fn main() {
    let filename = std::env::args().nth(1).expect("no filename given");

    println!("converting {}", filename);

    // filepath to BufRead
    let file = std::fs::File::open(&filename).expect("failed to open file");
    let mut reader = std::io::BufReader::new(file);

    let cloud = GaussianCloud(parse_ply(&mut reader).expect("failed to parse ply file"));

    // write cloud to .gcloud file (replace the .ply extension);
    // with_extension handles paths that contain other dots, e.g. ./scenes/test.ply
    let gcloud_filename = std::path::Path::new(&filename).with_extension("gcloud");
    // let gcloud_file = std::fs::File::create(&gcloud_filename).expect("failed to create file");
    // let mut writer = std::io::BufWriter::new(gcloud_file);

    // serialize_into(&mut writer, &cloud).expect("failed to encode cloud");

    // write gzip-compressed .gcloud
    let gz_file = std::fs::File::create(&gcloud_filename).expect("failed to create file");
    let mut gz_writer = std::io::BufWriter::new(gz_file);
    let mut gz_encoder = GzEncoder::new(&mut gz_writer, Compression::default());  // TODO: consider switching to fast (or support multiple options), default is a bit slow
    serialize_into(&mut gz_encoder, &cloud).expect("failed to encode cloud");
}
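
One possible way to address the TODO is to let the caller pick the compression level. Below is a minimal sketch using only the standard library; `parse_compression_level` is a hypothetical helper, not something in the repository. It assumes the usual zlib-style scale that flate2 exposes, where `Compression::fast()` is level 1, `Compression::default()` is level 6 and `Compression::best()` is level 9.

```rust
/// Hypothetical helper (not part of the repository): maps an optional CLI
/// word to a zlib-style compression level in 0..=9.
pub fn parse_compression_level(arg: Option<&str>) -> Result<u32, String> {
    match arg {
        // No argument keeps the current behaviour (default level).
        None | Some("default") => Ok(6),
        Some("fast") => Ok(1),
        Some("best") => Ok(9),
        // Anything else must be an explicit numeric level.
        Some(other) => match other.parse::<u32>() {
            Ok(level) if level <= 9 => Ok(level),
            _ => Err(format!("invalid compression level: {}", other)),
        },
    }
}
```

In `main`, the result could then be passed on as `GzEncoder::new(&mut gz_writer, Compression::new(level))`, with the level read from a second command line argument. That argument position is an assumption of this sketch.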

example asset

would it be possible to provide scenes/test.ply here?

add more particle system functionality (e.g. lifetime, color)

// TODO: add more particle system functionality (e.g. lifetime, color)

use rand::{
    prelude::Distribution,
    Rng,
};
use std::marker::Copy;

use bevy::{
    prelude::*,
    reflect::TypeUuid,
    render::render_resource::ShaderType,
};
use bytemuck::{
    Pod,
    Zeroable,
};
use serde::{
    Deserialize,
    Serialize,
};


// TODO: add more particle system functionality (e.g. lifetime, color)
#[derive(
    Clone,
    Debug,
    Copy,
    PartialEq,
    Reflect,
    ShaderType,
    Pod,
    Zeroable,
    Serialize,
    Deserialize,
)]
#[repr(C)]
pub struct ParticleBehavior {
    pub indicies: [u32; 4],
    pub velocity: [f32; 4],
    pub acceleration: [f32; 4],
    pub jerk: [f32; 4],
}

impl Default for ParticleBehavior {
    fn default() -> Self {
        Self {
            indicies: [0, 0, 0, 0],
            velocity: [0.0, 0.0, 0.0, 0.0],
            acceleration: [0.0, 0.0, 0.0, 0.0],
            jerk: [0.0, 0.0, 0.0, 0.0],
        }
    }
}

#[derive(
    Asset,
    Clone,
    Debug,
    Default,
    PartialEq,
    Reflect,
    TypeUuid,
    Serialize,
    Deserialize,
)]
#[uuid = "ac2f08eb-6463-2131-6772-51571ea332d5"]
pub struct ParticleBehaviors(pub Vec<ParticleBehavior>);


impl Distribution<ParticleBehavior> for rand::distributions::Standard {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ParticleBehavior {
        ParticleBehavior {
            acceleration: [
                rng.gen_range(-0.01..0.01),
                rng.gen_range(-0.01..0.01),
                rng.gen_range(-0.01..0.01),
                rng.gen_range(-0.01..0.01),
            ],
            jerk: [
                rng.gen_range(-0.0001..0.0001),
                rng.gen_range(-0.0001..0.0001),
                rng.gen_range(-0.0001..0.0001),
                rng.gen_range(-0.0001..0.0001),
            ],
            velocity: [
                rng.gen_range(-1.0..1.0),
                rng.gen_range(-1.0..1.0),
                rng.gen_range(-1.0..1.0),
                rng.gen_range(-1.0..1.0),
            ],
            ..Default::default()
        }
    }
}

pub fn random_particle_behaviors(n: usize) -> ParticleBehaviors {
    let mut rng = rand::thread_rng();
    let mut behaviors = Vec::with_capacity(n);
    for i in 0..n {
        let mut behavior: ParticleBehavior = rng.gen();
        behavior.indicies[0] = i as u32;
        behaviors.push(behavior);
    }

    ParticleBehaviors(behaviors)
}
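
A rough sketch of what lifetime and color support might look like on the CPU side. Everything here is hypothetical: `ParticleState`, its fields, and the choice of three-component vectors are illustrative and do not match the `ParticleBehavior` layout above. Motion is integrated jerk, then acceleration, then velocity, then position, and color is a linear blend over the particle's lifetime.

```rust
/// Hypothetical CPU-side particle state; field names are illustrative only.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ParticleState {
    pub position: [f32; 3],
    pub velocity: [f32; 3],
    pub acceleration: [f32; 3],
    pub jerk: [f32; 3],
    pub age: f32,
    pub lifetime: f32,
    pub start_color: [f32; 4],
    pub end_color: [f32; 4],
}

impl ParticleState {
    /// Advances the particle by `dt` seconds.
    /// Returns `false` once the particle has reached its lifetime.
    pub fn step(&mut self, dt: f32) -> bool {
        for i in 0..3 {
            self.acceleration[i] += self.jerk[i] * dt;
            self.velocity[i] += self.acceleration[i] * dt;
            self.position[i] += self.velocity[i] * dt;
        }
        self.age += dt;
        self.age < self.lifetime
    }

    /// Linearly interpolates from `start_color` to `end_color` over the lifetime.
    pub fn color(&self) -> [f32; 4] {
        let t = if self.lifetime > 0.0 {
            (self.age / self.lifetime).clamp(0.0, 1.0)
        } else {
            1.0
        };
        let mut out = [0.0; 4];
        for i in 0..4 {
            out[i] = self.start_color[i] + (self.end_color[i] - self.start_color[i]) * t;
        }
        out
    }
}
```

On the GPU the same fields would need to satisfy the `ShaderType` and `Pod` layout rules, which is why the existing struct pads everything to four components.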

add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if...

supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer

// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None

use bevy::{
    prelude::*,
    asset::LoadState,
    ecs::system::{
        lifetimeless::SRes,
        SystemParamItem,
    },
    reflect::TypeUuid,
    render::{
        render_asset::{
            RenderAsset,
            RenderAssetPlugin,
            PrepareAssetError,
        },
        render_resource::*,
        renderer::RenderDevice,
    },
};
use bytemuck::{
    Pod,
    Zeroable,
};
use static_assertions::assert_cfg;

use crate::{
    GaussianCloud,
    GaussianCloudSettings,
};


#[cfg(feature = "sort_radix")]
pub mod radix;

#[cfg(feature = "sort_rayon")]
pub mod rayon;


assert_cfg!(
    any(
        feature = "sort_radix",
        feature = "sort_rayon",
    ),
    "no sort mode enabled",
);


#[derive(
    Component,
    Debug,
    Clone,
    PartialEq,
    Reflect,
)]
pub enum SortMode {
    None,

    #[cfg(feature = "sort_radix")]
    Radix,

    #[cfg(feature = "sort_rayon")]
    Rayon,
}

impl Default for SortMode {
    #[allow(unreachable_code)]
    fn default() -> Self {
        #[cfg(feature = "sort_rayon")]
        return Self::Rayon;

        #[cfg(feature = "sort_radix")]
        return Self::Radix;

        Self::None
    }
}


#[derive(Default)]
pub struct SortPlugin;

impl Plugin for SortPlugin {
    fn build(&self, app: &mut App) {
        #[cfg(feature = "sort_radix")]
        app.add_plugins(radix::RadixSortPlugin);

        #[cfg(feature = "sort_rayon")]
        app.add_plugins(rayon::RayonSortPlugin);


        app.register_type::<SortedEntries>();
        app.init_asset::<SortedEntries>();
        app.register_asset_reflect::<SortedEntries>();

        app.add_plugins(RenderAssetPlugin::<SortedEntries>::default());

        app.add_systems(Update, auto_insert_sorted_entries);
    }
}


fn auto_insert_sorted_entries(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
    gaussian_clouds_res: Res<Assets<GaussianCloud>>,
    mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
    gaussian_clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
        &GaussianCloudSettings,
        Without<Handle<SortedEntries>>,
    )>,
) {
    for (
        entity,
        gaussian_cloud_handle,
        _settings,
        _,
    ) in gaussian_clouds.iter() {
        // // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection)
        // if settings.sort_mode == SortMode::None {
        //     continue;
        // }

        if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
            continue;
        }

        let Some(cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) else {
            continue;
        };

        // TODO: move gaussian_cloud and sorted_entry assets into an asset bundle
        let sorted_entries = sorted_entries_res.add(SortedEntries {
            sorted: (0..cloud.gaussians.len())
                .map(|idx| {
                    SortEntry {
                        key: 1,
                        index: idx as u32,
                    }
                })
                .collect(),
        });

        commands.entity(entity)
            .insert(sorted_entries);
    }
}


#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Reflect,
    ShaderType,
    Pod,
    Zeroable,
)]
#[repr(C)]
pub struct SortEntry {
    pub key: u32,
    pub index: u32,
}

// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None
// supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer
#[derive(
    Clone,
    Asset,
    Debug,
    Default,
    PartialEq,
    Reflect,
    TypeUuid,
)]
#[uuid = "ac2f08eb-fa13-ccdd-ea11-51571ea332d5"]
pub struct SortedEntries {
    pub sorted: Vec<SortEntry>,
}

impl RenderAsset for SortedEntries {
    type ExtractedAsset = SortedEntries;
    type PreparedAsset = GpuSortedEntry;
    type Param = SRes<RenderDevice>;

    fn extract_asset(&self) -> Self::ExtractedAsset {
        self.clone()
    }

    fn prepare_asset(
        sorted_entries: Self::ExtractedAsset,
        render_device: &mut SystemParamItem<Self::Param>,
    ) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
        let sorted_entry_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: Some("sorted_entry_buffer"),
            contents: bytemuck::cast_slice(sorted_entries.sorted.as_slice()),
            usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST | BufferUsages::STORAGE,
        });

        let count = sorted_entries.sorted.len();

        Ok(GpuSortedEntry {
            sorted_entry_buffer,
            count,
        })
    }
}


// TODO: support instancing and multiple cameras
//       separate entry_buffer_a binding into a unique bind group to optimize buffer updates
#[derive(Debug, Clone)]
pub struct GpuSortedEntry {
    pub sorted_entry_buffer: Buffer,
    pub count: usize,
}
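
For the "CPU sorting in main world" part, a minimal sketch of how sort keys could be filled in before upload. The assumptions are loud ones: positions are plain `[f32; 3]` values, the camera is described by a position and a forward vector, and back-to-front order is wanted. `SortEntry` is redefined locally so the snippet stands alone, and the function names are hypothetical.

```rust
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct SortEntry {
    pub key: u32,
    pub index: u32,
}

/// Maps an f32 to a u32 whose unsigned ordering matches the float ordering
/// (for non-NaN inputs): negative values get all bits flipped, non-negative
/// values get the sign bit set.
pub fn float_to_sortable_u32(value: f32) -> u32 {
    let bits = value.to_bits();
    if bits & 0x8000_0000 != 0 {
        !bits
    } else {
        bits | 0x8000_0000
    }
}

/// Builds entries sorted back to front (farthest first) along `forward`.
pub fn cpu_sort_entries(
    positions: &[[f32; 3]],
    camera_position: [f32; 3],
    forward: [f32; 3],
) -> Vec<SortEntry> {
    let mut entries: Vec<SortEntry> = positions
        .iter()
        .enumerate()
        .map(|(index, p)| {
            // Signed distance along the view direction.
            let depth = (0..3)
                .map(|i| (p[i] - camera_position[i]) * forward[i])
                .sum::<f32>();
            // Invert the key so that larger depth comes first in ascending order.
            SortEntry {
                key: !float_to_sortable_u32(depth),
                index: index as u32,
            }
        })
        .collect();
    entries.sort_unstable_by_key(|e| e.key);
    entries
}
```

The resulting `Vec<SortEntry>` has the same shape as `SortedEntries::sorted`, so in principle it could replace the placeholder `key: 1` entries created in `auto_insert_sorted_entries`.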

support instancing and multiple cameras

separate entry_buffer_a binding into a unique bind group to optimize buffer updates

// TODO: support instancing and multiple cameras

// TODO: support instancing and multiple cameras
//       separate entry_buffer_a binding into a unique bind group to optimize buffer updates
#[derive(Debug, Clone)]
pub struct GpuSortedEntry {
    pub sorted_entry_buffer: Buffer,
    pub count: usize,
}
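
Since draw order depends on the camera, multi-camera support probably needs one sorted list per (view, cloud) pair rather than one per cloud. A small sketch of that bookkeeping follows; `PerViewSortCache`, `ViewId` and `CloudId` are hypothetical names, and in Bevy the ids would more likely be `Entity` values and the lists GPU buffers.

```rust
use std::collections::HashMap;

/// Hypothetical identifiers; stand-ins for view and cloud entities.
pub type ViewId = u32;
pub type CloudId = u32;

/// Keeps one sorted index list per (view, cloud) pair, since the draw order
/// depends on which camera is looking at the cloud.
#[derive(Default)]
pub struct PerViewSortCache {
    sorted: HashMap<(ViewId, CloudId), Vec<u32>>,
}

impl PerViewSortCache {
    /// Re-sorts one cloud for one view, farthest point first.
    /// `depths[i]` is the view-space depth of gaussian `i` for that view.
    pub fn update(&mut self, view: ViewId, cloud: CloudId, depths: &[f32]) {
        let mut order: Vec<u32> = (0..depths.len() as u32).collect();
        order.sort_by(|&a, &b| depths[b as usize].total_cmp(&depths[a as usize]));
        self.sorted.insert((view, cloud), order);
    }

    /// Returns the cached draw order, if this pair has been sorted before.
    pub fn get(&self, view: ViewId, cloud: CloudId) -> Option<&[u32]> {
        self.sorted.get(&(view, cloud)).map(Vec::as_slice)
    }
}
```

The cost is memory: every extra camera adds another index list per cloud, which is one reason to keep the per-view buffer in its own bind group.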

allow swapping of sort backends

requires GaussianCloud RenderAsset dependency
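
One way to think about swappable backends is a small trait that every sorter implements. The sketch below is CPU-only and hypothetical: `SortBackend`, `StdSort`, `LsdRadixSort` and `sort_with` are not names from the repository. The radix variant uses four 8-bit passes that ping-pong between two buffers, loosely in the spirit of the a/b entry buffers in the GPU code. The digit width of the actual shader is an assumption here.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SortEntry {
    pub key: u32,
    pub index: u32,
}

/// Hypothetical abstraction over interchangeable sort implementations.
pub trait SortBackend {
    fn sort(&self, entries: &mut Vec<SortEntry>);
}

/// Stable comparison sort from the standard library.
pub struct StdSort;

impl SortBackend for StdSort {
    fn sort(&self, entries: &mut Vec<SortEntry>) {
        entries.sort_by_key(|e| e.key);
    }
}

/// Least-significant-digit radix sort with four 8-bit passes.
pub struct LsdRadixSort;

impl SortBackend for LsdRadixSort {
    fn sort(&self, entries: &mut Vec<SortEntry>) {
        let mut src = std::mem::take(entries);
        let mut dst = src.clone();
        for pass in 0..4u32 {
            let shift = pass * 8;
            // Histogram of the current digit.
            let mut offsets = [0usize; 256];
            for e in &src {
                offsets[((e.key >> shift) & 0xff) as usize] += 1;
            }
            // Exclusive prefix sum turns counts into start offsets.
            let mut running = 0usize;
            for slot in offsets.iter_mut() {
                let count = *slot;
                *slot = running;
                running += count;
            }
            // Stable scatter into the other buffer, then swap roles.
            for e in &src {
                let digit = ((e.key >> shift) & 0xff) as usize;
                dst[offsets[digit]] = *e;
                offsets[digit] += 1;
            }
            std::mem::swap(&mut src, &mut dst);
        }
        // After an even number of swaps the result is back in `src`.
        *entries = src;
    }
}

/// Runs whichever backend the caller selected.
pub fn sort_with(backend: &dyn SortBackend, entries: &mut Vec<SortEntry>) {
    backend.sort(entries);
}
```

With something like this, `SortMode` could select a boxed backend at runtime instead of relying only on `cfg` features, though the GPU path would still need its own render-graph wiring.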

// TODO: allow swapping of sort backends

use bevy::{
    prelude::*,
    asset::{
        load_internal_asset,
        LoadState,
    },
    core_pipeline::core_3d::CORE_3D,
    ecs::system::{
        lifetimeless::SRes,
        SystemParamItem,
    },
    render::{
        render_asset::RenderAssets,
        render_resource::*,
        renderer::{
            RenderContext,
            RenderDevice,
        },
        render_graph::{
            Node,
            NodeRunError,
            RenderGraphApp,
            RenderGraphContext,
        },
        Render,
        RenderApp,
        RenderSet,
        view::ViewUniformOffset,
    },
};

use crate::{
    gaussian::GaussianCloud,
    render::{
        GaussianCloudBindGroup,
        GaussianCloudPipeline,
        GaussianUniformBindGroups,
        GaussianViewBindGroup,
        ShaderDefines,
        shader_defs,
    },
};


const RADIX_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(6234673214);
const TEMPORAL_SORT_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(1634543224);

pub mod node {
    pub const RADIX_SORT: &str = "radix_sort";
}


#[derive(Default)]
pub struct RadixSortPlugin;

impl Plugin for RadixSortPlugin {
    fn build(&self, app: &mut App) {
        load_internal_asset!(
            app,
            RADIX_SHADER_HANDLE,
            "radix.wgsl",
            Shader::from_wgsl
        );

        load_internal_asset!(
            app,
            TEMPORAL_SORT_SHADER_HANDLE,
            "temporal.wgsl",
            Shader::from_wgsl
        );

        if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
            render_app
                .add_render_graph_node::<RadixSortNode>(
                    CORE_3D,
                    node::RADIX_SORT,
                )
                .add_render_graph_edge(
                    CORE_3D,
                    node::RADIX_SORT,
                    bevy::core_pipeline::core_3d::graph::node::PREPASS,
                );

            render_app
                .add_systems(
                    Render,
                    (
                        queue_radix_bind_group.in_set(RenderSet::QueueMeshes),
                    ),
                );
        }
    }

    fn finish(&self, app: &mut App) {
        if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
            render_app
                .init_resource::<RadixSortPipeline>();
        }
    }
}


// TODO: allow swapping of sort backends
//      requires GaussianCloud RenderAsset dependency
#[derive(Debug, Clone)]
pub struct GpuRadixBuffers {
    pub sorting_global_buffer: Buffer,
    pub sorting_status_counter_buffer: Buffer,
    pub sorting_pass_buffers: [Buffer; 4],
    pub entry_buffer_a: Buffer,
    pub entry_buffer_b: Buffer,
}
impl GpuRadixBuffers {
    pub fn new(
        count: usize,
        render_device: &mut SystemParamItem<SRes<RenderDevice>>,
    ) -> Self {
        let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
            label: Some("sorting global buffer"),
            size: ShaderDefines::default().sorting_buffer_size as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let sorting_status_counter_buffer = render_device.create_buffer(&BufferDescriptor {
            label: Some("status counters buffer"),
            size: ShaderDefines::default().sorting_status_counters_buffer_size(count) as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let sorting_pass_buffers = (0..4)
            .map(|idx| {
                render_device.create_buffer_with_data(&BufferInitDescriptor {
                    label: format!("sorting pass buffer {}", idx).as_str().into(),
                    contents: &[idx as u8, 0, 0, 0],
                    usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
                })
            })
            .collect::<Vec<Buffer>>()
            .try_into()
            .unwrap();

        let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
            label: Some("entry buffer a"),
            size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
            label: Some("entry buffer b"),
            size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        GpuRadixBuffers {
            sorting_global_buffer,
            sorting_status_counter_buffer,
            sorting_pass_buffers,
            entry_buffer_a,
            entry_buffer_b,
        }
    }
}


#[derive(Resource)]
pub struct RadixSortPipeline {
    pub radix_sort_layout: BindGroupLayout,
    pub radix_sort_pipelines: [CachedComputePipelineId; 3],
}

impl FromWorld for RadixSortPipeline {
    fn from_world(render_world: &mut World) -> Self {
        let render_device = render_world.resource::<RenderDevice>();
        let gaussian_cloud_pipeline = render_world.resource::<GaussianCloudPipeline>();

        let sorting_buffer_entry = BindGroupLayoutEntry {
            binding: 1,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(ShaderDefines::default().sorting_buffer_size as u64),
            },
            count: None,
        };

        let sorting_status_counters_buffer_entry = BindGroupLayoutEntry {
            binding: 2,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(ShaderDefines::default().sorting_status_counters_buffer_size(1) as u64),
            },
            count: None,
        };

        let draw_indirect_buffer_entry = BindGroupLayoutEntry {
            binding: 3,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(std::mem::size_of::<wgpu::util::DrawIndirect>() as u64),
            },
            count: None,
        };

        let radix_sort_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("radix_sort_layout"),
            entries: &[
                BindGroupLayoutEntry {
                    binding: 0,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<u32>() as u64),
                    },
                    count: None,
                },
                sorting_buffer_entry,
                sorting_status_counters_buffer_entry,
                draw_indirect_buffer_entry,
                BindGroupLayoutEntry {
                    binding: 4,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
                BindGroupLayoutEntry {
                    binding: 5,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
            ],
        });

        let sorting_layout = vec![
            gaussian_cloud_pipeline.view_layout.clone(),
            gaussian_cloud_pipeline.gaussian_uniform_layout.clone(),
            gaussian_cloud_pipeline.gaussian_cloud_layout.clone(),
            radix_sort_layout.clone(),
        ];
        let shader_defs = shader_defs(false, false);

        let pipeline_cache = render_world.resource::<PipelineCache>();
        let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_a".into()),
            layout: sorting_layout.clone(),
            push_constant_ranges: vec![],
            shader: RADIX_SHADER_HANDLE,
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_a".into(),
        });

        let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_b".into()),
            layout: sorting_layout.clone(),
            push_constant_ranges: vec![],
            shader: RADIX_SHADER_HANDLE,
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_b".into(),
        });

        let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_c".into()),
            layout: sorting_layout.clone(),
            push_constant_ranges: vec![],
            shader: RADIX_SHADER_HANDLE,
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_c".into(),
        });

        RadixSortPipeline {
            radix_sort_layout,
            radix_sort_pipelines: [
                radix_sort_a,
                radix_sort_b,
                radix_sort_c,
            ],
        }
    }
}



#[derive(Component)]
pub struct RadixBindGroup {
    pub radix_sort_bind_groups: [BindGroup; 4],
}

pub fn queue_radix_bind_group(
    mut commands: Commands,
    radix_pipeline: Res<RadixSortPipeline>,
    render_device: Res<RenderDevice>,
    asset_server: Res<AssetServer>,
    gaussian_cloud_res: Res<RenderAssets<GaussianCloud>>,
    gaussian_clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
    )>,
) {
    for (entity, cloud_handle) in gaussian_clouds.iter() {
        if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) {
            continue;
        }

        let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
            continue;
        };

        let sorting_global_entry = BindGroupEntry {
            binding: 1,
            resource: BindingResource::Buffer(BufferBinding {
                buffer: &cloud.radix_sort_buffers.sorting_global_buffer,
                offset: 0,
                size: BufferSize::new(cloud.radix_sort_buffers.sorting_global_buffer.size()),
            }),
        };

        let sorting_status_counters_entry = BindGroupEntry {
            binding: 2,
            resource: BindingResource::Buffer(BufferBinding {
                buffer: &cloud.radix_sort_buffers.sorting_status_counter_buffer,
                offset: 0,
                size: BufferSize::new(cloud.radix_sort_buffers.sorting_status_counter_buffer.size()),
            }),
        };

        let draw_indirect_entry = BindGroupEntry {
            binding: 3,
            resource: BindingResource::Buffer(BufferBinding {
                buffer: &cloud.draw_indirect_buffer,
                offset: 0,
                size: BufferSize::new(cloud.draw_indirect_buffer.size()),
            }),
        };

        let radix_sort_bind_groups: [BindGroup; 4] = (0..4)
            .map(|idx| {
                render_device.create_bind_group(
                    format!("radix_sort_bind_group {}", idx).as_str(),
                    &radix_pipeline.radix_sort_layout,
                    &[
                        BindGroupEntry {
                            binding: 0,
                            resource: BindingResource::Buffer(BufferBinding {
                                buffer: &cloud.radix_sort_buffers.sorting_pass_buffers[idx],
                                offset: 0,
                                size: BufferSize::new(std::mem::size_of::<u32>() as u64),
                            }),
                        },
                        sorting_global_entry.clone(),
                        sorting_status_counters_entry.clone(),
                        draw_indirect_entry.clone(),
                        BindGroupEntry {
                            binding: 4,
                            resource: BindingResource::Buffer(BufferBinding {
                                buffer: if idx % 2 == 0 {
                                    &cloud.radix_sort_buffers.entry_buffer_a
                                } else {
                                    &cloud.radix_sort_buffers.entry_buffer_b
                                },
                                offset: 0,
                                size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64),
                            }),
                        },
                        BindGroupEntry {
                            binding: 5,
                            resource: BindingResource::Buffer(BufferBinding {
                                buffer: if idx % 2 == 0 {
                                    &cloud.radix_sort_buffers.entry_buffer_b
                                } else {
                                    &cloud.radix_sort_buffers.entry_buffer_a
                                },
                                offset: 0,
                                size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64),
                            }),
                        },
                    ],
                )
            })
            .collect::<Vec<BindGroup>>()
            .try_into()
            .unwrap();

        commands.entity(entity).insert(RadixBindGroup {
            radix_sort_bind_groups,
        });
    }
}






pub struct RadixSortNode {
    gaussian_clouds: QueryState<(
        &'static Handle<GaussianCloud>,
        &'static GaussianCloudBindGroup,
        &'static RadixBindGroup,
    )>,
    initialized: bool,
    view_bind_group: QueryState<(
        &'static GaussianViewBindGroup,
        &'static ViewUniformOffset,
    )>,
}

impl FromWorld for RadixSortNode {
    fn from_world(world: &mut World) -> Self {
        Self {
            gaussian_clouds: world.query(),
            initialized: false,
            view_bind_group: world.query(),
        }
    }
}

impl Node for RadixSortNode {
    fn update(&mut self, world: &mut World) {
        let pipeline = world.resource::<RadixSortPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();

        if !self.initialized {
            let mut pipelines_loaded = true;
            for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
                if let CachedPipelineState::Ok(_) =
                        pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
                {
                    continue;
                }

                pipelines_loaded = false;
            }

            self.initialized = pipelines_loaded;

            if !self.initialized {
                return;
            }
        }

        self.gaussian_clouds.update_archetypes(world);
        self.view_bind_group.update_archetypes(world);
    }

    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), NodeRunError> {
        if !self.initialized {
            return Ok(());
        }

        let pipeline_cache = world.resource::<PipelineCache>();
        let pipeline = world.resource::<RadixSortPipeline>();
        let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();

        let command_encoder = render_context.command_encoder();

        for (
            view_bind_group,
            view_uniform_offset,
        ) in self.view_bind_group.iter_manual(world) {
            for (
                cloud_handle,
                cloud_bind_group,
                radix_bind_group,
            ) in self.gaussian_clouds.iter_manual(world) {
                let cloud = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap().get(cloud_handle).unwrap();

                let radix_digit_places = ShaderDefines::default().radix_digit_places;

                {
                    command_encoder.clear_buffer(
                        &cloud.radix_sort_buffers.sorting_global_buffer,
                        0,
                        None,
                    );

                    command_encoder.clear_buffer(
                        &cloud.radix_sort_buffers.sorting_status_counter_buffer,
                        0,
                        None,
                    );

                    command_encoder.clear_buffer(
                        &cloud.draw_indirect_buffer,
                        0,
                        None,
                    );
                }

                {
                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &radix_bind_group.radix_sort_bind_groups[1],
                        &[],
                    );

                    let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
                    pass.set_pipeline(radix_sort_a);

                    let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
                    pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);


                    let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
                    pass.set_pipeline(radix_sort_b);

                    pass.dispatch_workgroups(1, radix_digit_places, 1);
                }

                for pass_idx in 0..radix_digit_places {
                    if pass_idx > 0 {
                        command_encoder.clear_buffer(
                            &cloud.radix_sort_buffers.sorting_status_counter_buffer,
                            0,
                            None,
                        );
                    }

                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
                    pass.set_pipeline(radix_sort_c);

                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &radix_bind_group.radix_sort_bind_groups[pass_idx as usize],
                        &[],
                    );

                    let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
                    pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1);
                }
            }
        }

        Ok(())
    }
}
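
The dispatch sizes above come from a ceiling division of the gaussian count by the entries handled per workgroup. A minimal std-only sketch of that arithmetic follows; the helper name `workgroup_count` and any concrete entry counts are illustrative assumptions, not part of the crate.

```rust
/// Hypothetical helper: number of workgroups needed so that groups of
/// `workgroup_entries` items cover `count` items (ceiling division, the same
/// arithmetic as `(count + entries - 1) / entries` in the dispatch calls).
fn workgroup_count(count: u32, workgroup_entries: u32) -> u32 {
    assert!(workgroup_entries > 0, "workgroup_entries must be non-zero");
    // `u32::div_ceil` also avoids the overflow that `count + entries - 1`
    // would hit when `count` is close to `u32::MAX`.
    count.div_ceil(workgroup_entries)
}
```

An empty cloud yields zero workgroups, and a count that is an exact multiple of the entry size does not get an extra, empty workgroup.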

follow bevy's crate structure

// TODO: move to editor crate

use bevy::{
    prelude::*,
    app::AppExit,
    core::Name,
};
use bevy_inspector_egui::quick::WorldInspectorPlugin;
use bevy_panorbit_camera::{
    PanOrbitCamera,
    PanOrbitCameraPlugin,
};

use bevy_gaussian_splatting::{
    Gaussian,
    GaussianCloud,
    GaussianCloudSettings,
    GaussianSplattingBundle,
    GaussianSplattingPlugin,
    SphericalHarmonicCoefficients,
    utils::setup_hooks,
};


// TODO: move to editor crate
pub struct GaussianSplattingViewer {
    pub editor: bool,
    pub esc_close: bool,
    pub show_fps: bool,
    pub width: f32,
    pub height: f32,
    pub name: String,
}

impl Default for GaussianSplattingViewer {
    fn default() -> GaussianSplattingViewer {
        GaussianSplattingViewer {
            editor: true,
            esc_close: true,
            show_fps: true,
            width: 1920.0,
            height: 1080.0,
            name: "bevy_gaussian_splatting".to_string(),
        }
    }
}


pub fn setup_aabb_obb_compare(
    mut commands: Commands,
    mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
    let mut blue_sh = SphericalHarmonicCoefficients::default();
    blue_sh.coefficients[2] = 5.0;

    let blue_aabb_gaussian = Gaussian {
        position: [0.0, 0.0, 0.0, 1.0],
        rotation: [0.89, 0.0, -0.432, 0.144],
        scale_opacity: [10.0, 1.0, 1.0, 0.5],
        spherical_harmonic: blue_sh,
    };

    commands.spawn((
        GaussianSplattingBundle {
            cloud: gaussian_assets.add(
                GaussianCloud(vec![
                    blue_aabb_gaussian,
                    blue_aabb_gaussian,
                ])
            ),
            settings: GaussianCloudSettings {
                aabb: true,
                visualize_bounding_box: true,
                ..default()
            },
            ..default()
        },
        Name::new("gaussian_cloud_aabb"),
    ));

    let mut red_sh = SphericalHarmonicCoefficients::default();
    red_sh.coefficients[0] = 5.0;

    let red_obb_gaussian = Gaussian {
        position: [0.0, 0.0, 0.0, 1.0],
        rotation: [0.89, 0.0, -0.432, 0.144],
        scale_opacity: [10.0, 1.0, 1.0, 0.5],
        spherical_harmonic: red_sh,
    };

    commands.spawn((
        GaussianSplattingBundle {
            cloud: gaussian_assets.add(
                GaussianCloud(vec![
                    red_obb_gaussian,
                    red_obb_gaussian,
                ])
            ),
            settings: GaussianCloudSettings {
                aabb: false,
                visualize_bounding_box: true,
                ..default()
            },
            ..default()
        },
        Name::new("gaussian_cloud_obb"),
    ));

    commands.spawn((
        Camera3dBundle {
            transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
            ..default()
        },
        PanOrbitCamera {
            allow_upside_down: true,
            ..default()
        },
    ));
}

fn compare_aabb_obb_app() {
    let config = GaussianSplattingViewer::default();
    let mut app = App::new();

    // setup for gaussian viewer app
    app.insert_resource(ClearColor(Color::rgb_u8(0, 0, 0)));
    app.add_plugins(
        DefaultPlugins
        .set(ImagePlugin::default_nearest())
        .set(WindowPlugin {
            primary_window: Some(Window {
                fit_canvas_to_parent: false,
                mode: bevy::window::WindowMode::Windowed,
                present_mode: bevy::window::PresentMode::AutoVsync,
                prevent_default_event_handling: false,
                resolution: (config.width, config.height).into(),
                title: config.name.clone(),
                ..default()
            }),
            ..default()
        }),
    );
    app.add_plugins((
        PanOrbitCameraPlugin,
    ));

    if config.editor {
        app.add_plugins(WorldInspectorPlugin::new());
    }

    if config.esc_close {
        app.add_systems(Update, esc_close);
    }

    // setup for gaussian splatting
    app.add_plugins(GaussianSplattingPlugin);
    app.add_systems(Startup, setup_aabb_obb_compare);

    app.run();
}

pub fn esc_close(
    keys: Res<Input<KeyCode>>,
    mut exit: EventWriter<AppExit>
) {
    if keys.just_pressed(KeyCode::Escape) {
        exit.send(AppExit);
    }
}

pub fn main() {
    setup_hooks();
    compare_aabb_obb_app();
}

temporal sort

https://api.github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L1109

            None => return RenderCommandResult::Failure,
        };

        pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]);
        pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);

        pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);

        RenderCommandResult::Success
    }
}





struct RadixSortNode {
    gaussian_clouds: QueryState<(
        &'static Handle<GaussianCloud>,
        &'static GaussianCloudBindGroup
    )>,
    initialized: bool,
    pipeline_idx: Option<u32>,
    view_bind_group: QueryState<(
        &'static GaussianViewBindGroup,
        &'static ViewUniformOffset,
    )>,
}

impl FromWorld for RadixSortNode {
    fn from_world(world: &mut World) -> Self {
        Self {
            gaussian_clouds: world.query(),
            initialized: false,
            pipeline_idx: None,
            view_bind_group: world.query(),
        }
    }
}

impl render_graph::Node for RadixSortNode {
    fn update(&mut self, world: &mut World) {
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();

        if !self.initialized {
            let mut pipelines_loaded = true;
            for sort_pipeline in pipeline.radix_sort_pipelines.iter() {
                if let CachedPipelineState::Ok(_) =
                        pipeline_cache.get_compute_pipeline_state(*sort_pipeline)
                {
                    continue;
                }

                pipelines_loaded = false;
            }

            self.initialized = pipelines_loaded;

            if !self.initialized {
                return;
            }
        }

        if self.pipeline_idx.is_none() {
            self.pipeline_idx = Some(0);
        } else {
            self.pipeline_idx = Some((self.pipeline_idx.unwrap() + 1) % pipeline.radix_sort_pipelines.len() as u32);
        }

        self.gaussian_clouds.update_archetypes(world);
        self.view_bind_group.update_archetypes(world);
    }

    fn run(
        &self,
        _graph: &mut render_graph::RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), render_graph::NodeRunError> {
        if !self.initialized || self.pipeline_idx.is_none() {
            return Ok(());
        }

        let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort

        let pipeline_cache = world.resource::<PipelineCache>();
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();

        let command_encoder = render_context.command_encoder();

        for (
            view_bind_group,
            view_uniform_offset,
        ) in self.view_bind_group.iter_manual(world) {
            for (
                cloud_handle,
                cloud_bind_group
            ) in self.gaussian_clouds.iter_manual(world) {
                let cloud = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap().get(cloud_handle).unwrap();

                let radix_digit_places = ShaderDefines::default().radix_digit_places;

                command_encoder.clear_buffer(
                    &cloud.sorting_global_buffer,
                    0,
                    None,
                );

                {
                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    // TODO: view/global
                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[1],
                        &[],
                    );

                    let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
                    pass.set_pipeline(radix_sort_a);

                    let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
                    pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);


                    let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
                    pass.set_pipeline(radix_sort_b);

                    pass.dispatch_workgroups(1, radix_digit_places, 1);
                }

                for pass_idx in 0..radix_digit_places {
                    if pass_idx > 0 {
                        let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::<u32>() as u32;
                        command_encoder.clear_buffer(
                            &cloud.sorting_global_buffer,
                            0,
                            std::num::NonZeroU64::new(size as u64).unwrap().into()
                        );
                    }

                    let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
                    pass.set_pipeline(radix_sort_c);

                    pass.set_bind_group(
                        0,
                        &view_bind_group.value,
                        &[view_uniform_offset.offset],
                    );
                    pass.set_bind_group(
                        1,
                        gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                        &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex
                    );
                    pass.set_bind_group(
                        2,
                        &cloud_bind_group.cloud_bind_group,
                        &[]
                    );
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[pass_idx as usize],
                        &[],
                    );

                    let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
                    pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
                }
            }
        }


        Ok(())
    }
}
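
The issue does not define "temporal sort". One possible reading, suggested by the unused `pipeline_idx` that cycles once per `update`, is to spread the radix digit passes over several frames instead of running them all in one. The following CPU-side sketch illustrates that idea under this assumption; `TemporalRadixSort`, its constants, and the 8-bit digit width are hypothetical and not taken from the crate.

```rust
const RADIX_BITS: u32 = 8; // assumed digit width
const RADIX_BASE: usize = 1 << RADIX_BITS;
const DIGIT_PLACES: u32 = 32 / RADIX_BITS;

/// Hypothetical sketch: an LSD radix sort over (key, index) pairs that
/// performs exactly one digit pass per call to `step` (one call per frame).
struct TemporalRadixSort {
    entries: Vec<(u32, u32)>,
    scratch: Vec<(u32, u32)>,
    pass: u32,
}

impl TemporalRadixSort {
    fn new(keys: &[u32]) -> Self {
        let entries: Vec<(u32, u32)> = keys
            .iter()
            .enumerate()
            .map(|(index, &key)| (key, index as u32))
            .collect();
        let scratch = vec![(0, 0); entries.len()];
        Self { entries, scratch, pass: 0 }
    }

    /// Runs one digit pass; returns `true` once every digit place is sorted.
    fn step(&mut self) -> bool {
        if self.pass >= DIGIT_PLACES {
            return true;
        }
        let shift = self.pass * RADIX_BITS;
        let digit = |key: u32| ((key >> shift) as usize) & (RADIX_BASE - 1);

        // Histogram of the current digit.
        let mut offsets = [0usize; RADIX_BASE];
        for &(key, _) in &self.entries {
            offsets[digit(key)] += 1;
        }
        // Exclusive prefix sum turns counts into start offsets.
        let mut running = 0;
        for slot in offsets.iter_mut() {
            let count = *slot;
            *slot = running;
            running += count;
        }
        // Stable scatter into the scratch buffer, then swap buffers.
        for &entry in &self.entries {
            let d = digit(entry.0);
            self.scratch[offsets[d]] = entry;
            offsets[d] += 1;
        }
        std::mem::swap(&mut self.entries, &mut self.scratch);
        self.pass += 1;
        self.pass >= DIGIT_PLACES
    }
}
```

In this sketch the keys are frozen when the sort starts, so the order only becomes valid after the final pass. A renderer using this approach would keep drawing with the previous complete ordering until `step` returns `true`.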

move sort to render world, use extracted views and update the existing buffer in...

// TODO: move sort to render world, use extracted views and update the existing buffer instead of creating new

use bevy::{
    prelude::*,
    asset::LoadState,
    utils::Instant,
};

use rayon::prelude::*;

use crate::{
    GaussianCloud,
    GaussianCloudSettings,
    sort::{
        SortedEntries,
        SortMode,
    },
};


#[derive(Default)]
pub struct RayonSortPlugin;

impl Plugin for RayonSortPlugin {
    fn build(&self, app: &mut App) {
        app.add_systems(Update, rayon_sort);
    }
}

pub fn rayon_sort(
    asset_server: Res<AssetServer>,
    gaussian_clouds_res: Res<Assets<GaussianCloud>>,
    mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
    gaussian_clouds: Query<(
        &Handle<GaussianCloud>,
        &Handle<SortedEntries>,
        &GaussianCloudSettings,
    )>,
    cameras: Query<(
        &GlobalTransform,
        &Camera3d,
    )>,
    mut last_camera_position: Local<Vec3>,
    mut last_sort_time: Local<Option<Instant>>,
) {
    let period = std::time::Duration::from_millis(100);
    if let Some(last_sort_time) = last_sort_time.as_ref() {
        if last_sort_time.elapsed() < period {
            return;
        }
    }

    // TODO: move sort to render world, use extracted views and update the existing buffer instead of creating new

    for (
        camera_transform,
        _camera,
    ) in cameras.iter() {
        let camera_position = camera_transform.compute_transform().translation;
        if *last_camera_position == camera_position {
            return;
        }

        for (
            gaussian_cloud_handle,
            sorted_entries_handle,
            settings,
        ) in gaussian_clouds.iter() {
            if settings.sort_mode != SortMode::Rayon {
                continue;
            }

            if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
                continue;
            }

            if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) {
                continue;
            }

            if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) {
                if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) {
                    assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len());

                    *last_camera_position = camera_position;
                    *last_sort_time = Some(Instant::now());

                    gaussian_cloud.gaussians.par_iter()
                        .zip(sorted_entries.sorted.par_iter_mut())
                        .enumerate()
                        .for_each(|(idx, (gaussian, sort_entry))| {
                            let position = Vec3::from_slice(gaussian.position.as_ref());
                            let delta = camera_position - position;

                            sort_entry.key = bytemuck::cast(delta.length_squared());
                            sort_entry.index = idx as u32;
                        });

                    // `total_cmp` defines an order for NaN keys, so a malformed position cannot panic the sort
                    sorted_entries.sorted.par_sort_unstable_by(|a, b| {
                        bytemuck::cast::<u32, f32>(b.key).total_cmp(&bytemuck::cast::<u32, f32>(a.key))
                    });

                    // TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)
                }
            }
        }
    }
}
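
The sort above stores the squared camera distance as the raw bits of an `f32` and orders entries farthest first. A std-only sketch of the same key scheme follows, assuming plain `[f32; 3]` positions and using `f32::to_bits` in place of `bytemuck::cast`; the function names are illustrative. Because a squared length is never negative, its bit pattern orders the same way as its value, so this sketch compares the keys as integers directly.

```rust
/// Squared distance from the camera to a point, stored as raw `f32` bits.
fn depth_key(camera: [f32; 3], position: [f32; 3]) -> u32 {
    let delta = [
        camera[0] - position[0],
        camera[1] - position[1],
        camera[2] - position[2],
    ];
    let length_squared = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];
    length_squared.to_bits()
}

/// Returns gaussian indices ordered back to front (farthest from the camera first).
fn sort_back_to_front(camera: [f32; 3], positions: &[[f32; 3]]) -> Vec<u32> {
    let mut entries: Vec<(u32, u32)> = positions
        .iter()
        .enumerate()
        .map(|(index, position)| (depth_key(camera, *position), index as u32))
        .collect();

    // Non-negative IEEE 754 floats sort identically as floats and as unsigned
    // bit patterns, so an integer comparison is enough and cannot panic.
    entries.sort_unstable_by(|a, b| b.0.cmp(&a.0));

    entries.into_iter().map(|(_, index)| index).collect()
}
```

The same property is what allows a GPU radix sort to operate on these keys as plain `u32` values.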

support loading from directory of images

// TODO: support loading from directory of images

use bevy::{
    prelude::*,
    asset::LoadState,
    ecs::query::QueryItem,
    render::{
        extract_component::{
            ExtractComponent,
            ExtractComponentPlugin,
        },
        Render,
        RenderApp,
        RenderSet,
        render_asset::RenderAssets,
        render_resource::{
            BindGroup,
            BindGroupLayout,
            BindGroupLayoutDescriptor,
            BindGroupLayoutEntry,
            BindGroupEntry,
            BindingType,
            BindingResource,
            Extent3d,
            TextureDimension,
            TextureFormat,
            TextureSampleType,
            TextureUsages,
            TextureViewDimension,
            ShaderStages,
        },
        renderer::RenderDevice,
    },
};
use static_assertions::assert_cfg;

#[allow(unused_imports)]
use crate::{
    gaussian::{
        cloud::GaussianCloud,
        f32::{
            PositionVisibility,
            Rotation,
            ScaleOpacity,
        },
    },
    material::spherical_harmonics::{
        SH_COEFF_COUNT,
        SH_VEC4_PLANES,
        SphericalHarmonicCoefficients,
    },
    render::{
        GaussianCloudPipeline,
        GpuGaussianCloud,
    },
};


// TODO: support loading from directory of images


assert_cfg!(
    feature = "planar",
    "texture rendering is only supported with the `planar` feature enabled",
);

assert_cfg!(
    not(feature = "f32"),
    "f32 texture support is not implemented yet",
);


#[derive(Component, Clone, Debug, Reflect)]
pub struct TextureBuffers {
    position_visibility: Handle<Image>,
    spherical_harmonics: Handle<Image>,

    #[cfg(feature = "f16")]
    rotation_scale_opacity: Handle<Image>,

    #[cfg(feature = "f32")]
    rotation: Handle<Image>,
    #[cfg(feature = "f32")]
    scale_opacity: Handle<Image>,
}

impl ExtractComponent for TextureBuffers {
    type Query = &'static Self;

    type Filter = ();
    type Out = Self;

    fn extract_component(texture_buffers: QueryItem<'_, Self::Query>) -> Option<Self::Out> {
        texture_buffers.clone().into()
    }
}


#[derive(Default)]
pub struct BufferTexturePlugin;

impl Plugin for BufferTexturePlugin {
    fn build(&self, app: &mut App) {
        app.register_type::<TextureBuffers>();
        app.add_plugins(ExtractComponentPlugin::<TextureBuffers>::default());

        app.add_systems(Update, queue_textures);

        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_systems(
            Render,
            queue_gpu_texture_buffers.in_set(RenderSet::PrepareAssets),
        );
    }
}


#[derive(Component, Clone, Debug)]
pub struct GpuTextureBuffers {
    pub bind_group: BindGroup,
}

pub fn queue_gpu_texture_buffers(
    mut commands: Commands,
    pipeline: Res<GaussianCloudPipeline>,
    render_device: Res<RenderDevice>,
    gpu_images: Res<RenderAssets<Image>>,
    clouds: Query<(
        Entity,
        &TextureBuffers,
    )>,
) {
    for (entity, texture_buffers,) in clouds.iter() {
        #[cfg(feature = "f16")]
        let bind_group = render_device.create_bind_group(
            Some("texture_gaussian_cloud_bind_group"),
            &pipeline.gaussian_cloud_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.rotation_scale_opacity).unwrap().texture_view
                    ),
                },
            ],
        );

        #[cfg(feature = "f32")]
        let bind_group = render_device.create_bind_group(
            Some("texture_gaussian_cloud_bind_group"),
            &pipeline.gaussian_cloud_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.rotation).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 3,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.scale_opacity).unwrap().texture_view
                    ),
                },
            ],
        );

        commands.entity(entity).insert(GpuTextureBuffers { bind_group });
    }
}


// TODO: support asset change detection and reupload
fn queue_textures(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
    gaussian_cloud_res: Res<Assets<GaussianCloud>>,
    mut images: ResMut<Assets<Image>>,
    clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
        Without<TextureBuffers>,
    )>,
) {
    for (entity, cloud_handle, _) in clouds.iter() {
        if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) {
            continue;
        }

        let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
            continue;
        };

        let square = cloud.len_sqrt_ceil() as u32;
        let extent_1d = Extent3d {
            width: square,
            height: square, // TODO: shrink height to save memory (consider fixed width)
            depth_or_array_layers: 1,
        };

        let mut position_visibility = Image::new(
            extent_1d,
            TextureDimension::D2,
            bytemuck::cast_slice(cloud.position_visibility.as_slice()).to_vec(),
            TextureFormat::Rgba32Float,
        );
        position_visibility.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
        let position_visibility = images.add(position_visibility);

        let texture_buffers: TextureBuffers;

        #[cfg(feature = "f16")]
        {
            let planar_spherical_harmonics: Vec<u32> = (0..SH_VEC4_PLANES)
                .flat_map(|plane_index| {
                    cloud.spherical_harmonic.iter()
                        .flat_map(move |sh| {
                            let start_index = plane_index * 4;
                            let end_index = std::cmp::min(start_index + 4, sh.coefficients.len());

                            let mut depthwise = sh.coefficients[start_index..end_index].to_vec();
                            depthwise.resize(4, 0);

                            depthwise
                        })
                })
                .collect();

            let mut spherical_harmonics = Image::new(
                Extent3d {
                    width: square,
                    height: square,
                    depth_or_array_layers: SH_VEC4_PLANES as u32,
                },
                TextureDimension::D2,
                bytemuck::cast_slice(planar_spherical_harmonics.as_slice()).to_vec(),
                TextureFormat::Rgba32Uint,
            );
            spherical_harmonics.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
            let spherical_harmonics = images.add(spherical_harmonics);

            let mut rotation_scale_opacity = Image::new(
                extent_1d,
                TextureDimension::D2,
                bytemuck::cast_slice(cloud.rotation_scale_opacity_packed128.as_slice()).to_vec(),
                TextureFormat::Rgba32Uint,
            );
            rotation_scale_opacity.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
            let rotation_scale_opacity = images.add(rotation_scale_opacity);

            texture_buffers = TextureBuffers {
                position_visibility,
                spherical_harmonics,
                rotation_scale_opacity,
            };
        }

        #[cfg(feature = "f32")]
        {
            texture_buffers = TextureBuffers {
                position_visibility,
                spherical_harmonics: todo!(),
                rotation: todo!(),
                scale_opacity: todo!(),
            };
        }

        commands.entity(entity).insert(texture_buffers);
    }
}
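
The `f16` branch of `queue_textures` lays the spherical harmonic data out plane by plane: each plane holds one 4-component texel per gaussian, and a partial final plane is zero-padded. A std-only sketch of that layout follows, assuming coefficients are already packed as `u32` words; `pack_planar` is an illustrative name, not a crate function.

```rust
/// Hypothetical sketch of the planar layout: for each plane, emit one
/// 4-component texel per gaussian, zero-padding texels that run past the end
/// of a gaussian's coefficient list.
fn pack_planar(coefficients: &[Vec<u32>], planes: usize) -> Vec<u32> {
    let mut packed = Vec::with_capacity(planes * coefficients.len() * 4);

    for plane in 0..planes {
        for gaussian in coefficients {
            // Clamp both ends so an out-of-range plane yields an all-zero texel
            // (this sketch clamps `start` as well as `end`).
            let start = (plane * 4).min(gaussian.len());
            let end = (start + 4).min(gaussian.len());

            packed.extend_from_slice(&gaussian[start..end]);
            packed.extend(std::iter::repeat(0).take(4 - (end - start)));
        }
    }

    packed
}
```

With this ordering, plane `p` occupies one contiguous run of the output, which matches uploading each plane as one layer of a 2D array texture.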


pub fn get_sorted_bind_group_layout(
    render_device: &RenderDevice,
) -> BindGroupLayout {
    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_sorted_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
        ],
    })
}


#[cfg(feature = "f16")]
pub fn get_bind_group_layout(
    render_device: &RenderDevice,
    _read_only: bool
) -> BindGroupLayout {
    let sh_view_dimension = if SH_VEC4_PLANES == 1 {
        TextureViewDimension::D2
    } else {
        TextureViewDimension::D2Array
    };

    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_f16_gaussian_cloud_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Float {
                        filterable: false,
                    },
                    multisampled: false,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: sh_view_dimension,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
        ],
    })
}


#[cfg(feature = "f32")]
pub fn get_bind_group_layout(
    render_device: &RenderDevice,
    read_only: bool
) -> BindGroupLayout {
    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_f32_gaussian_cloud_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<PositionVisibility>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<SphericalHarmonicCoefficients>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<Rotation>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 3,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<ScaleOpacity>() as u64),
                },
                count: None,
            },
        ],
    })
}

allow setting shader defines via API

https://api.github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L509

            ],
        });

        let sorting_buffer_entry = BindGroupLayoutEntry {
            binding: 1,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(ShaderDefines::default().sorting_buffer_size as u64),
            },
            count: None,
        };

        let draw_indirect_buffer_entry = BindGroupLayoutEntry {
            binding: 2,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(std::mem::size_of::<wgpu::util::DrawIndirect>() as u64),
            },
            count: None,
        };

        let radix_sort_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("radix_sort_layout"),
            entries: &[
                BindGroupLayoutEntry {
                    binding: 0,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<u32>() as u64),
                    },
                    count: None,
                },
                sorting_buffer_entry,
                draw_indirect_buffer_entry,
                BindGroupLayoutEntry {
                    binding: 3,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
                BindGroupLayoutEntry {
                    binding: 4,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
            ],
        });

        let sorted_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("sorted_layout"),
            entries: &[
                BindGroupLayoutEntry {
                    binding: 5,
                    visibility: ShaderStages::VERTEX_FRAGMENT,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
            ],
        });

        let compute_layout = vec![
            view_layout.clone(),
            gaussian_uniform_layout.clone(),
            gaussian_cloud_layout.clone(),
            radix_sort_layout.clone(),
        ];
        let shader = GAUSSIAN_SHADER_HANDLE.typed();
        let shader_defs = shader_defs(false, false);

        let pipeline_cache = render_world.resource::<PipelineCache>();
        let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_a".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_a".into(),
        });

        let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_b".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_b".into(),
        });

        let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_c".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_c".into(),
        });


        let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("temporal_sort_flip".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "temporal_sort_flip".into(),
        });

        let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("temporal_sort_flop".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "temporal_sort_flop".into(),
        });

        GaussianCloudPipeline {
            gaussian_cloud_layout,
            gaussian_uniform_layout,
            view_layout,
            shader: shader.clone(),
            radix_sort_layout,
            radix_sort_pipelines: [
                radix_sort_a,
                radix_sort_b,
                radix_sort_c,
            ],
            temporal_sort_pipelines: [
                temporal_sort_flip,
                temporal_sort_flop,
            ],
            sorted_layout,
        }
    }
}

// TODO: allow setting shader defines via API
struct ShaderDefines {
    radix_bits_per_digit: u32,
    radix_digit_places: u32,
    radix_base: u32,
    entries_per_invocation_a: u32,
    entries_per_invocation_c: u32,
    workgroup_invocations_a: u32,
    workgroup_invocations_c: u32,
    workgroup_entries_a: u32,
    workgroup_entries_c: u32,
    max_tile_count_c: u32,
    sorting_buffer_size: usize,

    temporal_sort_window_size: u32,
}

impl Default for ShaderDefines {
    fn default() -> Self {
        let radix_bits_per_digit = 8;
        let radix_digit_places = 32 / radix_bits_per_digit;
        let radix_base = 1 << radix_bits_per_digit;
        let entries_per_invocation_a = 4;
        let entries_per_invocation_c = 4;
        let workgroup_invocations_a = radix_base * radix_digit_places;
        let workgroup_invocations_c = radix_base;
        let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a;
        let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c;
        // ceiling division: tiles needed to cover up to 10 million entries
        let max_tile_count_c = (10_000_000 + workgroup_entries_c - 1) / workgroup_entries_c;
        let sorting_buffer_size = (
            radix_base as usize *
            (radix_digit_places as usize + max_tile_count_c as usize) *
            std::mem::size_of::<u32>()
        ) + std::mem::size_of::<u32>() * 5;

        Self {
            radix_bits_per_digit,
            radix_digit_places,
            radix_base,
            entries_per_invocation_a,
            entries_per_invocation_c,
            workgroup_invocations_a,
            workgroup_invocations_c,
            workgroup_entries_a,
            workgroup_entries_c,
            max_tile_count_c,
            sorting_buffer_size,

            temporal_sort_window_size: 16,
        }
    }
}

fn shader_defs(
    aabb: bool,
    visualize_bounding_box: bool,
) -> Vec<ShaderDefVal> {
    let defines = ShaderDefines::default();
    let mut shader_defs = vec![
        ShaderDefVal::UInt("MAX_SH_COEFF_COUNT".into(), MAX_SH_COEFF_COUNT as u32),
        ShaderDefVal::UInt("RADIX_BASE".into(), defines.radix_base),
        ShaderDefVal::UInt("RADIX_BITS_PER_DIGIT".into(), defines.radix_bits_per_digit),
        ShaderDefVal::UInt("RADIX_DIGIT_PLACES".into(), defines.radix_digit_places),
        ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_A".into(), defines.entries_per_invocation_a),
        ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_C".into(), defines.entries_per_invocation_c),
        ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a),
        ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c),
        ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c),
        ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c),

        ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size),
    ];

    if aabb {
        shader_defs.push("USE_AABB".into());
    } else {
        shader_defs.push("USE_OBB".into());
    }

    if visualize_bounding_box {
        shader_defs.push("VISUALIZE_BOUNDING_BOX".into());
    }

    shader_defs
}

#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct GaussianCloudPipelineKey {
    pub aabb: bool,
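
One way the TODO above could be addressed is to let callers supply the independent parameters and derive the rest, as `ShaderDefines::default()` already does. The sketch below is standalone Rust and not the crate's API: `ShaderDefinesConfig`, `DerivedDefines`, and `max_entries` are hypothetical names, and the defaults simply mirror the constants shown above.

```rust
/// Hypothetical user-facing knobs; names are illustrative, not part of the crate.
#[derive(Clone, Copy, Debug)]
pub struct ShaderDefinesConfig {
    pub radix_bits_per_digit: u32,
    pub entries_per_invocation_c: u32,
    /// Upper bound on sortable entries (the snippet above hard-codes 10 million).
    pub max_entries: u32,
}

impl Default for ShaderDefinesConfig {
    fn default() -> Self {
        Self {
            radix_bits_per_digit: 8,
            entries_per_invocation_c: 4,
            max_entries: 10_000_000,
        }
    }
}

/// Values derived from the config, following the arithmetic in `ShaderDefines::default()`.
#[derive(Debug, PartialEq, Eq)]
pub struct DerivedDefines {
    pub radix_base: u32,
    pub radix_digit_places: u32,
    pub workgroup_entries_c: u32,
    pub max_tile_count_c: u32,
    pub sorting_buffer_size: usize,
}

impl ShaderDefinesConfig {
    pub fn derive(&self) -> DerivedDefines {
        let radix_digit_places = 32 / self.radix_bits_per_digit;
        let radix_base = 1u32 << self.radix_bits_per_digit;
        let workgroup_entries_c = radix_base * self.entries_per_invocation_c;
        // Ceiling division: number of tiles needed to cover `max_entries`.
        let max_tile_count_c =
            (self.max_entries + workgroup_entries_c - 1) / workgroup_entries_c;
        let word = std::mem::size_of::<u32>();
        let sorting_buffer_size = radix_base as usize
            * (radix_digit_places as usize + max_tile_count_c as usize)
            * word
            + word * 5;

        DerivedDefines {
            radix_base,
            radix_digit_places,
            workgroup_entries_c,
            max_tile_count_c,
            sorting_buffer_size,
        }
    }
}

fn main() {
    let derived = ShaderDefinesConfig::default().derive();
    println!("{derived:?}");
}
```

With the default values this yields a radix base of 256, 9766 tiles, and a sorting buffer of 10,004,500 bytes, so the bind group layout's `min_binding_size` would follow whatever the caller configures.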

add correctness test (use CPU gaussian pipeline to compare results)

// TODO: add correctness test (use CPU gaussian pipeline to compare results)

use std::sync::{
    Arc,
    Mutex,
};

use bevy::{
    prelude::*,
    app::AppExit,
    core::FrameCount,
    render::view::screenshot::ScreenshotManager,
    window::PrimaryWindow,
};

use bevy_gaussian_splatting::{
    GaussianCloud,
    GaussianSplattingBundle,
    random_gaussians,
};

use _harness::{
    TestHarness,
    test_harness_app,
    TestStateArc,
};

mod _harness;


fn main() {
    let mut app = test_harness_app(TestHarness {
        resolution: (512.0, 512.0),
    });

    app.add_systems(Startup, setup);
    app.add_systems(Update, capture_ready);

    app.run();
}

fn setup(
    mut commands: Commands,
    mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
    let cloud = gaussian_assets.add(random_gaussians(10000));

    commands.spawn((
        GaussianSplattingBundle {
            cloud,
            ..default()
        },
        Name::new("gaussian_cloud"),
    ));

    commands.spawn((
        Camera3dBundle {
            transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
            ..default()
        },
    ));
}

fn check_image_equality(image: &Image, other: &Image) -> bool {
    // compare dimensions first, then the raw bytes; comparing the buffers directly
    // also catches a length mismatch, which `zip` would silently truncate
    image.width() == other.width()
        && image.height() == other.height()
        && image.data == other.data
}

fn test_stability(captures: Arc<Mutex<Vec<Image>>>) {
    // every frame must match its predecessor; with exact equality this is the same as
    // every frame matching the first one
    let all_frames_similar = captures.lock().unwrap()
        .windows(2)
        .all(|pair| check_image_equality(&pair[0], &pair[1]));
    assert!(all_frames_similar, "all frames are not the same");
}

fn save_captures(captures: Arc<Mutex<Vec<Image>>>) {
    captures.lock().unwrap().iter()
        .enumerate()
        .for_each(|(i, image)| {
            let path = format!("target/tmp/test_gaussian_frame_{}.png", i);

            let dyn_img = image.clone().try_into_dynamic().unwrap();
            let img = dyn_img.to_rgba8();
            img.save(path).unwrap();
        });
}

fn capture_ready(
    // gaussian_cloud_assets: Res<Assets<GaussianCloud>>,
    // asset_server: Res<AssetServer>,
    // gaussian_clouds: Query<
    //     Entity,
    //     &Handle<GaussianCloud>,
    // >,
    main_window: Query<Entity, With<PrimaryWindow>>,
    mut screenshot_manager: ResMut<ScreenshotManager>,
    mut exit: EventWriter<AppExit>,
    frame_count: Res<FrameCount>,
    state: Local<TestStateArc>,
    buffer: Local<Arc<Mutex<Vec<Image>>>>,
) {
    let buffer = buffer.to_owned();

    let buffer_frames = 10;
    let wait_frames = 10;  // wait for gaussian cloud to load
    if frame_count.0 < wait_frames {
        return;
    }

    let state_clone = Arc::clone(&state);
    let buffer_clone = Arc::clone(&buffer);

    let mut state = state.lock().unwrap();
    state.test_loaded = true;

    if state.test_completed {
        {
            let captures = buffer.lock().unwrap();
            let frame_count = captures.len();
            assert_eq!(frame_count, buffer_frames, "captured {} frames, expected {}", frame_count, buffer_frames);
        }

        save_captures(buffer.clone());
        test_stability(buffer);
        // TODO: add correctness test (use CPU gaussian pipeline to compare results)

        exit.send(AppExit);
        return;
    }

    if let Ok(window_entity) = main_window.get_single() {
        screenshot_manager.take_screenshot(window_entity, move |image: Image| {
            let has_non_zero_data = image.data.iter().any(|&x| x != 0);
            assert!(has_non_zero_data, "screenshot is all zeros");

            let mut buffer = buffer_clone.lock().unwrap();
            buffer.push(image);

            if buffer.len() >= buffer_frames {
                let mut state = state_clone.lock().unwrap();
                state.test_completed = true;
            }
        }).unwrap();
    }
}
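
For the correctness test named in the TODO, a byte-exact comparison like `check_image_equality` is likely too strict, since a CPU reference and the GPU pipeline may round differently. A minimal sketch of a tolerance-based comparison over raw RGBA8 bytes follows; `max_channel_diff`, `images_match_within`, and the tolerance value are assumptions for illustration, not part of the test harness.

```rust
/// Largest per-channel absolute difference between two byte buffers,
/// or `None` if the buffers differ in length.
fn max_channel_diff(a: &[u8], b: &[u8]) -> Option<u8> {
    if a.len() != b.len() {
        return None;
    }

    Some(
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| x.abs_diff(y))
            .max()
            .unwrap_or(0), // two empty buffers are trivially identical
    )
}

/// True when both buffers have the same length and no channel differs
/// by more than `tolerance`.
fn images_match_within(a: &[u8], b: &[u8], tolerance: u8) -> bool {
    matches!(max_channel_diff(a, b), Some(diff) if diff <= tolerance)
}

fn main() {
    // Hypothetical pixel data standing in for a GPU capture and a CPU reference.
    let gpu = [10u8, 20, 30, 255];
    let cpu = [11u8, 19, 30, 255];
    println!("match within 1: {}", images_match_within(&gpu, &cpu, 1));
}
```

A stricter or looser tolerance, or a per-pixel error budget, may suit the real pipeline better; the right threshold would need to be measured.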

abstract source of cloud_bind_group (e.g. packed vs. planar)

pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]); // TODO: abstract source of cloud_bind_group (e.g. packed vs. planar)

            None => return RenderCommandResult::Failure,
        };

        pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]); // TODO: abstract source of cloud_bind_group (e.g. packed vs. planar)
        pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);

        #[cfg(feature = "webgl2")]
        pass.draw(0..4, 0..gpu_gaussian_cloud.count as u32);

        #[cfg(not(feature = "webgl2"))]
        pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);

        RenderCommandResult::Success
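
One possible shape for the abstraction this TODO asks for is a small trait that hides whether the cloud data is packed or planar. The sketch below is standalone and uses a stand-in `BindGroup` type rather than the real GPU handle; `CloudBindGroupSource`, `PackedCloud`, `PlanarCloud`, and `bind_cloud` are all hypothetical names.

```rust
/// Stand-in for a GPU bind group handle; the real renderer would use its own type.
#[derive(Debug, PartialEq, Eq)]
struct BindGroup(&'static str);

/// Hypothetical abstraction over where the cloud bind group comes from.
trait CloudBindGroupSource {
    fn cloud_bind_group(&self) -> &BindGroup;
}

struct PackedCloud {
    bind_group: BindGroup,
}

struct PlanarCloud {
    bind_group: BindGroup,
}

impl CloudBindGroupSource for PackedCloud {
    fn cloud_bind_group(&self) -> &BindGroup {
        &self.bind_group
    }
}

impl CloudBindGroupSource for PlanarCloud {
    fn cloud_bind_group(&self) -> &BindGroup {
        &self.bind_group
    }
}

/// Mirrors the draw command above: it only needs some cloud bind group for slot 2
/// and does not care which storage layout produced it.
fn bind_cloud(source: &dyn CloudBindGroupSource) -> (u32, &BindGroup) {
    (2, source.cloud_bind_group())
}

fn main() {
    let packed = PackedCloud { bind_group: BindGroup("packed") };
    let planar = PlanarCloud { bind_group: BindGroup("planar") };

    println!("{:?}", bind_cloud(&packed));
    println!("{:?}", bind_cloud(&planar));
}
```

An enum with one variant per layout would work as well and avoids dynamic dispatch; which is preferable depends on how the feature flags are meant to combine.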

update DrawIndirect buffer during sort phase (GPU sort will override default Dra...

// TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)

use bevy::{
    prelude::*,
    asset::LoadState,
    utils::Instant,
};

use rayon::prelude::*;

use crate::{
    GaussianCloud,
    GaussianCloudSettings,
    sort::{
        SortedEntries,
        SortMode,
    },
};


#[derive(Default)]
pub struct RayonSortPlugin;

impl Plugin for RayonSortPlugin {
    fn build(&self, app: &mut App) {
        app.add_systems(Update, rayon_sort);
    }
}

pub fn rayon_sort(
    asset_server: Res<AssetServer>,
    gaussian_clouds_res: Res<Assets<GaussianCloud>>,
    mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
    gaussian_clouds: Query<(
        &Handle<GaussianCloud>,
        &Handle<SortedEntries>,
        &GaussianCloudSettings,
    )>,
    cameras: Query<(
        &GlobalTransform,
        &Camera3d,
    )>,
    mut last_camera_position: Local<Vec3>,
    mut last_sort_time: Local<Option<Instant>>,
) {
    let period = std::time::Duration::from_millis(100);
    if let Some(last_sort_time) = last_sort_time.as_ref() {
        if last_sort_time.elapsed() < period {
            return;
        }
    }

    // TODO: move sort to render world, use extracted views and update the existing buffer instead of creating new

    for (
        camera_transform,
        _camera,
    ) in cameras.iter() {
        let camera_position = camera_transform.compute_transform().translation;
        if *last_camera_position == camera_position {
            return;
        }

        for (
            gaussian_cloud_handle,
            sorted_entries_handle,
            settings,
        ) in gaussian_clouds.iter() {
            if settings.sort_mode != SortMode::Rayon {
                continue;
            }

            if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
                continue;
            }

            if Some(LoadState::Loading) == asset_server.get_load_state(sorted_entries_handle) {
                continue;
            }

            if let Some(gaussian_cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) {
                if let Some(sorted_entries) = sorted_entries_res.get_mut(sorted_entries_handle) {
                    assert_eq!(gaussian_cloud.gaussians.len(), sorted_entries.sorted.len());

                    *last_camera_position = camera_position;
                    *last_sort_time = Some(Instant::now());

                    gaussian_cloud.gaussians.par_iter()
                        .zip(sorted_entries.sorted.par_iter_mut())
                        .enumerate()
                        .for_each(|(idx, (gaussian, sort_entry))| {
                            let position = Vec3::from_slice(gaussian.position.as_ref());
                            let delta = camera_position - position;

                            sort_entry.key = bytemuck::cast(delta.length_squared());
                            sort_entry.index = idx as u32;
                        });

                    // sort back-to-front (farthest first); `total_cmp` avoids a panic if a key is NaN
                    sorted_entries.sorted.par_sort_unstable_by(|a, b| {
                        bytemuck::cast::<u32, f32>(b.key).total_cmp(&bytemuck::cast::<u32, f32>(a.key))
                    });

                    // TODO: update DrawIndirect buffer during sort phase (GPU sort will override default DrawIndirect)
                }
            }
        }
    }
}
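
`rayon_sort` stores each squared camera distance as the raw bits of an `f32` inside a `u32` key. Because squared distances are never negative, those bit patterns order the same way as the floats themselves, which is presumably what allows an integer radix sort to operate on the same keys. The standalone sketch below illustrates the idea with the standard library only; `SortEntry` and `sort_back_to_front` are illustrative stand-ins, not the crate's types.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct SortEntry {
    key: u32,
    index: u32,
}

fn length_squared(v: [f32; 3]) -> f32 {
    v[0] * v[0] + v[1] * v[1] + v[2] * v[2]
}

/// Builds one entry per position and orders them farthest-first.
fn sort_back_to_front(camera: [f32; 3], positions: &[[f32; 3]]) -> Vec<SortEntry> {
    let mut entries: Vec<SortEntry> = positions
        .iter()
        .enumerate()
        .map(|(idx, p)| {
            let delta = [camera[0] - p[0], camera[1] - p[1], camera[2] - p[2]];
            SortEntry {
                // Reinterpret the f32 bits as an integer key.
                key: length_squared(delta).to_bits(),
                index: idx as u32,
            }
        })
        .collect();

    // Descending integer order equals descending distance for non-negative,
    // non-NaN floats, so no float comparison is needed here.
    entries.sort_unstable_by(|a, b| b.key.cmp(&a.key));
    entries
}

fn main() {
    let camera = [0.0, 0.0, 5.0];
    let positions = [[0.0, 0.0, 4.0], [0.0, 0.0, -5.0], [0.0, 0.0, 0.0]];

    for entry in sort_back_to_front(camera, &positions) {
        println!("index {} at distance^2 {}", entry.index, f32::from_bits(entry.key));
    }
}
```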

(extract GaussianCloud, TextureBuffers) when feature buffer_texture is enabled

// TODO: (extract GaussianCloud, TextureBuffers) when feature buffer_texture is enabled

            usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        });

        // TODO: (extract GaussianCloud, TextureBuffers) when feature buffer_texture is enabled

        Ok(GpuGaussianCloud {
            count,
            draw_indirect_buffer,

            #[cfg(feature = "debug_gpu")]
            debug_gpu: gaussian_cloud,

            #[cfg(feature = "packed")]
            packed: packed::prepare_cloud(render_device, &gaussian_cloud),
            #[cfg(feature = "buffer_storage")]
            planar: planar::prepare_cloud(render_device, &gaussian_cloud),
        })
    }
}

overloaded system, move to resource setup system

// TODO: overloaded system, move to resource setup system

    asset_server: Res<AssetServer>,
    gaussian_cloud_res: Res<RenderAssets<GaussianCloud>>,
    sorted_entries_res: Res<RenderAssets<SortedEntries>>,

    #[cfg(feature = "buffer_storage")]
    gaussian_clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
        &Handle<SortedEntries>,
    )>,
    #[cfg(feature = "buffer_texture")]
    gaussian_clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
        &Handle<SortedEntries>,
        &texture::GpuTextureBuffers,
    )>,

    #[cfg(feature = "buffer_texture")]
    gpu_images: Res<RenderAssets<Image>>,
) {
    let Some(model) = gaussian_uniforms.buffer() else {
        return;
    };

    // TODO: overloaded system, move to resource setup system
    groups.base_bind_group = Some(render_device.create_bind_group(
        "gaussian_uniform_bind_group",
        &gaussian_cloud_pipeline.gaussian_uniform_layout,

prioritize mesh selection over export filter

println!("initial cloud size: {}", cloud.len());
cloud = (0..cloud.len())
    .filter(|&idx| {
        is_point_in_transformed_sphere(
            cloud.position(idx),
        )
    })
    .map(|idx| cloud.gaussian(idx))
    .collect();
println!("filtered position cloud size: {}", cloud.len());

// TODO: prioritize mesh selection over export filter

    let file = std::fs::File::open(&filename).expect("failed to open file");
    let mut reader = std::io::BufReader::new(file);

    let mut cloud = GaussianCloud::from_gaussians(
        parse_ply(&mut reader).expect("failed to parse ply file"),
    );

    // TODO: prioritize mesh selection over export filter
    // println!("initial cloud size: {}", cloud.len());
    // cloud = (0..cloud.len())
    //     .filter(|&idx| {
    //         is_point_in_transformed_sphere(
    //             cloud.position(idx),
    //         )
    //     })
    //     .map(|idx| cloud.gaussian(idx))
    //     .collect();
    // println!("filtered position cloud size: {}", cloud.len());

    #[cfg(feature = "query_sparse")]
    {
        let sparse_selection = SparseSelect::default().select(&cloud).invert(cloud.len());

        cloud = sparse_selection.indicies.iter()
            .map(|idx| cloud.gaussian(*idx))
            .collect();
        println!("sparsity filtered cloud size: {}", cloud.len());
    }

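    // derive the output path from the input path; splitting on the first '.' would
    // yield an empty name for inputs such as "./scene.ply"
    let gcloud_filename = std::path::Path::new(&filename)
        .with_extension("gcloud")
        .to_string_lossy()
        .into_owned();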

    write_gaussian_cloud_to_file(&cloud, &gcloud_filename);

    let post_encode_bytes = Byte::from_u64(std::fs::metadata(&gcloud_filename).expect("failed to get metadata").len());
    println!("output file size: {}", post_encode_bytes.get_appropriate_unit(UnitType::Decimal));
}
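
The commented-out export filter keeps only the gaussians whose position passes a sphere test. The crate's `is_point_in_transformed_sphere` is not shown here, so the sketch below assumes a simpler, untransformed variant; `is_point_in_sphere`, `filter_indices`, and their parameters are hypothetical.

```rust
/// Hypothetical helper: true when `point` lies inside or on the sphere.
fn is_point_in_sphere(point: [f32; 3], center: [f32; 3], radius: f32) -> bool {
    let d = [
        point[0] - center[0],
        point[1] - center[1],
        point[2] - center[2],
    ];

    // Compare squared lengths to avoid a square root.
    d[0] * d[0] + d[1] * d[1] + d[2] * d[2] <= radius * radius
}

/// Returns the indices of all positions that pass the sphere test,
/// following the index-based filter in the snippet above.
fn filter_indices(positions: &[[f32; 3]], center: [f32; 3], radius: f32) -> Vec<usize> {
    (0..positions.len())
        .filter(|&idx| is_point_in_sphere(positions[idx], center, radius))
        .collect()
}

fn main() {
    let positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, 0.5, 0.0]];
    let kept = filter_indices(&positions, [0.0, 0.0, 0.0], 1.0);
    println!("kept indices: {kept:?}");
}
```

Returning indices rather than copied gaussians keeps the filter composable with other selections, which may help when a mesh selection is meant to take priority over the export filter.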

support asset change detection and reupload

// TODO: support asset change detection and reupload

use bevy::{
    prelude::*,
    asset::LoadState,
    ecs::query::QueryItem,
    render::{
        extract_component::{
            ExtractComponent,
            ExtractComponentPlugin,
        },
        Render,
        RenderApp,
        RenderSet,
        render_asset::RenderAssets,
        render_resource::{
            BindGroup,
            BindGroupLayout,
            BindGroupLayoutDescriptor,
            BindGroupLayoutEntry,
            BindGroupEntry,
            BindingType,
            BindingResource,
            Extent3d,
            TextureDimension,
            TextureFormat,
            TextureSampleType,
            TextureUsages,
            TextureViewDimension,
            ShaderStages,
        },
        renderer::RenderDevice,
    },
};
use static_assertions::assert_cfg;

#[allow(unused_imports)]
use crate::{
    gaussian::{
        cloud::GaussianCloud,
        f32::{
            PositionVisibility,
            Rotation,
            ScaleOpacity,
        },
    },
    material::spherical_harmonics::{
        SH_COEFF_COUNT,
        SH_VEC4_PLANES,
        SphericalHarmonicCoefficients,
    },
    render::{
        GaussianCloudPipeline,
        GpuGaussianCloud,
    },
};


// TODO: support loading from directory of images


assert_cfg!(
    feature = "planar",
    "texture rendering is only supported with the `planar` feature enabled",
);

assert_cfg!(
    not(feature = "f32"),
    "f32 texture support is not implemented yet",
);


#[derive(Component, Clone, Debug, Reflect)]
pub struct TextureBuffers {
    position_visibility: Handle<Image>,
    spherical_harmonics: Handle<Image>,

    #[cfg(feature = "f16")]
    rotation_scale_opacity: Handle<Image>,

    #[cfg(feature = "f32")]
    rotation: Handle<Image>,
    #[cfg(feature = "f32")]
    scale_opacity: Handle<Image>,
}

impl ExtractComponent for TextureBuffers {
    type Query = &'static Self;

    type Filter = ();
    type Out = Self;

    fn extract_component(texture_buffers: QueryItem<'_, Self::Query>) -> Option<Self::Out> {
        texture_buffers.clone().into()
    }
}


#[derive(Default)]
pub struct BufferTexturePlugin;

impl Plugin for BufferTexturePlugin {
    fn build(&self, app: &mut App) {
        app.register_type::<TextureBuffers>();
        app.add_plugins(ExtractComponentPlugin::<TextureBuffers>::default());

        app.add_systems(Update, queue_textures);

        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_systems(
            Render,
            queue_gpu_texture_buffers.in_set(RenderSet::PrepareAssets),
        );
    }
}


#[derive(Component, Clone, Debug)]
pub struct GpuTextureBuffers {
    pub bind_group: BindGroup,
}

pub fn queue_gpu_texture_buffers(
    mut commands: Commands,
    // gaussian_cloud_pipeline: Res<GaussianCloudPipeline>,
    pipeline: ResMut<GaussianCloudPipeline>,
    render_device: ResMut<RenderDevice>,
    gpu_images: Res<RenderAssets<Image>>,
    clouds: Query<(
        Entity,
        &TextureBuffers,
    )>,
) {
    for (entity, texture_buffers,) in clouds.iter() {
        #[cfg(feature = "f16")]
        let bind_group = render_device.create_bind_group(
            Some("texture_gaussian_cloud_bind_group"),
            &pipeline.gaussian_cloud_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.rotation_scale_opacity).unwrap().texture_view
                    ),
                },
            ],
        );

        #[cfg(feature = "f32")]
        let bind_group = render_device.create_bind_group(
            Some("texture_gaussian_cloud_bind_group"),
            &pipeline.gaussian_cloud_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.rotation).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 3,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.scale_opacity).unwrap().texture_view
                    ),
                },
            ],
        );

        commands.entity(entity).insert(GpuTextureBuffers { bind_group });
    }
}


// TODO: support asset change detection and reupload
fn queue_textures(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
    gaussian_cloud_res: Res<Assets<GaussianCloud>>,
    mut images: ResMut<Assets<Image>>,
    clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
        Without<TextureBuffers>,
    )>,
) {
    for (entity, cloud_handle, _) in clouds.iter() {
        if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle){
            continue;
        }

        let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
            continue;
        };

        let square = cloud.len_sqrt_ceil() as u32;
        let extent_1d = Extent3d {
            width: square,
            height: square, // TODO: shrink height to save memory (consider fixed width)
            depth_or_array_layers: 1,
        };

        let mut position_visibility = Image::new(
            extent_1d,
            TextureDimension::D2,
            bytemuck::cast_slice(cloud.position_visibility.as_slice()).to_vec(),
            TextureFormat::Rgba32Float,
        );
        position_visibility.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
        let position_visibility = images.add(position_visibility);

        let texture_buffers: TextureBuffers;

        #[cfg(feature = "f16")]
        {
            let planar_spherical_harmonics: Vec<u32> = (0..SH_VEC4_PLANES)
                .flat_map(|plane_index| {
                    cloud.spherical_harmonic.iter()
                        .flat_map(move |sh| {
                            let start_index = plane_index * 4;
                            let end_index = std::cmp::min(start_index + 4, sh.coefficients.len());

                            let mut depthwise = sh.coefficients[start_index..end_index].to_vec();
                            depthwise.resize(4, 0);

                            depthwise
                        })
                })
                .collect();

            let mut spherical_harmonics = Image::new(
                Extent3d {
                    width: square,
                    height: square,
                    depth_or_array_layers: SH_VEC4_PLANES as u32,
                },
                TextureDimension::D2,
                bytemuck::cast_slice(planar_spherical_harmonics.as_slice()).to_vec(),
                TextureFormat::Rgba32Uint,
            );
            spherical_harmonics.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
            let spherical_harmonics = images.add(spherical_harmonics);

            let mut rotation_scale_opacity = Image::new(
                extent_1d,
                TextureDimension::D2,
                bytemuck::cast_slice(cloud.rotation_scale_opacity_packed128.as_slice()).to_vec(),
                TextureFormat::Rgba32Uint,
            );
            rotation_scale_opacity.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
            let rotation_scale_opacity = images.add(rotation_scale_opacity);

            texture_buffers = TextureBuffers {
                position_visibility,
                spherical_harmonics,
                rotation_scale_opacity,
            };
        }

        #[cfg(feature = "f32")]
        {
            texture_buffers = TextureBuffers {
                position_visibility,
                spherical_harmonics: todo!(),
                rotation: todo!(),
                scale_opacity: todo!(),
            };
        }

        commands.entity(entity).insert(texture_buffers);
    }
}


pub fn get_sorted_bind_group_layout(
    render_device: &RenderDevice,
) -> BindGroupLayout {
    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_sorted_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
        ],
    })
}


#[cfg(feature = "f16")]
pub fn get_bind_group_layout(
    render_device: &RenderDevice,
    _read_only: bool
) -> BindGroupLayout {
    let sh_view_dimension = if SH_VEC4_PLANES == 1 {
        TextureViewDimension::D2
    } else {
        TextureViewDimension::D2Array
    };

    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_f16_gaussian_cloud_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Float {
                        filterable: false,
                    },
                    multisampled: false,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: sh_view_dimension,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
        ],
    })
}


#[cfg(feature = "f32")]
pub fn get_bind_group_layout(
    render_device: &RenderDevice,
    read_only: bool
) -> BindGroupLayout {
    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_f32_gaussian_cloud_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<PositionVisibility>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<SphericalHarmonicCoefficients>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<Rotation>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 3,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<ScaleOpacity>() as u64),
                },
                count: None,
            },
        ],
    })
}

derive sorting_buffer_size from cloud count (with possible rounding to next powe...

https://api.github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L184

        gaussian_cloud: Self::ExtractedAsset,
        render_device: &mut SystemParamItem<Self::Param>,
    ) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
        let gaussian_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: Some("gaussian cloud buffer"),
            contents: bytemuck::cast_slice(gaussian_cloud.0.as_slice()),
            usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
        });

        let count = gaussian_cloud.0.len() as u32;

        // TODO: derive sorting_buffer_size from cloud count (with possible rounding to next power of 2)
        let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
            label: Some("sorting global buffer"),
            size: ShaderDefines::default().sorting_buffer_size as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
            label: Some("draw indirect buffer"),
            size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
            usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let sorting_pass_buffers = (0..4)
            .map(|idx| {
                render_device.create_buffer_with_data(&BufferInitDescriptor {
                    label: format!("sorting pass buffer {}", idx).as_str().into(),
                    contents: &[idx as u8, 0, 0, 0],
                    usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
                })
            })
            .collect::<Vec<Buffer>>()
            .try_into()
            .unwrap();

        let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
            label: Some("entry buffer a"),
            size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
            label: Some("entry buffer b"),
            size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        Ok(GpuGaussianCloud {
            gaussian_buffer,
            count,
            draw_indirect_buffer,
            sorting_global_buffer,
            sorting_pass_buffers,
            entry_buffer_a,
            entry_buffer_b,
        })
    }
}
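
The TODO above asks for the sorting buffer size to be derived from the cloud count, possibly rounded up to the next power of two. A minimal sketch of that arithmetic follows. The helper name `sorting_buffer_size` and the `bytes_per_entry` parameter are hypothetical, and the actual layout of `sorting_global_buffer` is not shown in this snippet, so treat the per-entry size as an assumption.

```rust
/// Hypothetical helper (not part of the crate): bytes needed for a buffer that
/// holds one fixed-size entry per gaussian, with the entry count rounded up to
/// the next power of two.
pub fn sorting_buffer_size(count: u32, bytes_per_entry: u64) -> u64 {
    // Widen to u64 before rounding so very large clouds cannot overflow u32.
    // `max(1)` keeps the buffer non-empty for an empty cloud.
    let padded_count = u64::from(count.max(1)).next_power_of_two();
    padded_count * bytes_per_entry
}
```

With 8-byte entries (the size of the `(u32, u32)` pairs used for the entry buffers above), a cloud of 10,000 gaussians would be padded to 16,384 entries. Whether padding is needed at all depends on how the sort shader indexes the buffer.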

move gaussian_cloud and sorted_entry assets into an asset bundle

// TODO: move gaussian_cloud and sorted_entry assets into an asset bundle

use bevy::{
    prelude::*,
    asset::LoadState,
    ecs::system::{
        lifetimeless::SRes,
        SystemParamItem,
    },
    reflect::TypeUuid,
    render::{
        render_asset::{
            RenderAsset,
            RenderAssetPlugin,
            PrepareAssetError,
        },
        render_resource::*,
        renderer::RenderDevice,
    },
};
use bytemuck::{
    Pod,
    Zeroable,
};
use static_assertions::assert_cfg;

use crate::{
    GaussianCloud,
    GaussianCloudSettings,
};


#[cfg(feature = "sort_radix")]
pub mod radix;

#[cfg(feature = "sort_rayon")]
pub mod rayon;


assert_cfg!(
    any(
        feature = "sort_radix",
        feature = "sort_rayon",
    ),
    "no sort mode enabled",
);


#[derive(
    Component,
    Debug,
    Clone,
    PartialEq,
    Reflect,
)]
pub enum SortMode {
    None,

    #[cfg(feature = "sort_radix")]
    Radix,

    #[cfg(feature = "sort_rayon")]
    Rayon,
}

impl Default for SortMode {
    #[allow(unreachable_code)]
    fn default() -> Self {
        #[cfg(feature = "sort_rayon")]
        return Self::Rayon;

        #[cfg(feature = "sort_radix")]
        return Self::Radix;

        Self::None
    }
}


#[derive(Default)]
pub struct SortPlugin;

impl Plugin for SortPlugin {
    fn build(&self, app: &mut App) {
        #[cfg(feature = "sort_radix")]
        app.add_plugins(radix::RadixSortPlugin);

        #[cfg(feature = "sort_rayon")]
        app.add_plugins(rayon::RayonSortPlugin);


        app.register_type::<SortedEntries>();
        app.init_asset::<SortedEntries>();
        app.register_asset_reflect::<SortedEntries>();

        app.add_plugins(RenderAssetPlugin::<SortedEntries>::default());

        app.add_systems(Update, auto_insert_sorted_entries);
    }
}


fn auto_insert_sorted_entries(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
    gaussian_clouds_res: Res<Assets<GaussianCloud>>,
    mut sorted_entries_res: ResMut<Assets<SortedEntries>>,
    gaussian_clouds: Query<
        (
            Entity,
            &Handle<GaussianCloud>,
            &GaussianCloudSettings,
        ),
        Without<Handle<SortedEntries>>,
    >,
) {
    for (
        entity,
        gaussian_cloud_handle,
        _settings,
    ) in gaussian_clouds.iter() {
        // // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection)
        // if settings.sort_mode == SortMode::None {
        //     continue;
        // }

        if Some(LoadState::Loading) == asset_server.get_load_state(gaussian_cloud_handle) {
            continue;
        }

        // Skip clouds whose asset data is not available yet.
        let Some(cloud) = gaussian_clouds_res.get(gaussian_cloud_handle) else {
            continue;
        };

        // TODO: move gaussian_cloud and sorted_entry assets into an asset bundle
        let sorted_entries = sorted_entries_res.add(SortedEntries {
            sorted: (0..cloud.gaussians.len())
                .map(|idx| {
                    SortEntry {
                        key: 1,
                        index: idx as u32,
                    }
                })
                .collect(),
        });

        commands.entity(entity)
            .insert(sorted_entries);
    }
}


#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    PartialEq,
    Reflect,
    ShaderType,
    Pod,
    Zeroable,
)]
#[repr(C)]
pub struct SortEntry {
    pub key: u32,
    pub index: u32,
}

// TODO: add RenderAssetPlugin for SortedEntries & auto-insert to GaussianCloudBundles if their sort mode is not None
// supports pre-sorting or CPU sorting in main world, initializes the sorting_entry_buffer
#[derive(
    Clone,
    Asset,
    Debug,
    Default,
    PartialEq,
    Reflect,
    TypeUuid,
)]
#[uuid = "ac2f08eb-fa13-ccdd-ea11-51571ea332d5"]
pub struct SortedEntries {
    pub sorted: Vec<SortEntry>,
}

impl RenderAsset for SortedEntries {
    type ExtractedAsset = SortedEntries;
    type PreparedAsset = GpuSortedEntry;
    type Param = SRes<RenderDevice>;

    fn extract_asset(&self) -> Self::ExtractedAsset {
        self.clone()
    }

    fn prepare_asset(
        sorted_entries: Self::ExtractedAsset,
        render_device: &mut SystemParamItem<Self::Param>,
    ) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
        let sorted_entry_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
            label: Some("sorted_entry_buffer"),
            contents: bytemuck::cast_slice(sorted_entries.sorted.as_slice()),
            usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST | BufferUsages::STORAGE,
        });

        let count = sorted_entries.sorted.len();

        Ok(GpuSortedEntry {
            sorted_entry_buffer,
            count,
        })
    }
}


// TODO: support instancing and multiple cameras
//       separate entry_buffer_a binding into a unique bind group to optimize buffer updates
#[derive(Debug, Clone)]
pub struct GpuSortedEntry {
    pub sorted_entry_buffer: Buffer,
    pub count: usize,
}

specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirecti...

// // TODO: specialize vertex shader for sort mode (e.g. draw_indirect but no sort indirection)
// if settings.sort_mode == SortMode::None {
//     continue;
// }


shrink height to save memory (consider fixed width)

height: square, // TODO: shrink height to save memory (consider fixed width)

use bevy::{
    prelude::*,
    asset::LoadState,
    ecs::query::QueryItem,
    render::{
        extract_component::{
            ExtractComponent,
            ExtractComponentPlugin,
        },
        Render,
        RenderApp,
        RenderSet,
        render_asset::RenderAssets,
        render_resource::{
            BindGroup,
            BindGroupLayout,
            BindGroupLayoutDescriptor,
            BindGroupLayoutEntry,
            BindGroupEntry,
            BindingType,
            BindingResource,
            Extent3d,
            TextureDimension,
            TextureFormat,
            TextureSampleType,
            TextureUsages,
            TextureViewDimension,
            ShaderStages,
        },
        renderer::RenderDevice,
    },
};
use static_assertions::assert_cfg;

#[allow(unused_imports)]
use crate::{
    gaussian::{
        cloud::GaussianCloud,
        f32::{
            PositionVisibility,
            Rotation,
            ScaleOpacity,
        },
    },
    material::spherical_harmonics::{
        SH_COEFF_COUNT,
        SH_VEC4_PLANES,
        SphericalHarmonicCoefficients,
    },
    render::{
        GaussianCloudPipeline,
        GpuGaussianCloud,
    },
};


// TODO: support loading from directory of images


assert_cfg!(
    feature = "planar",
    "texture rendering is only supported with the `planar` feature enabled",
);

assert_cfg!(
    not(feature = "f32"),
    "f32 texture support is not implemented yet",
);


#[derive(Component, Clone, Debug, Reflect)]
pub struct TextureBuffers {
    position_visibility: Handle<Image>,
    spherical_harmonics: Handle<Image>,

    #[cfg(feature = "f16")]
    rotation_scale_opacity: Handle<Image>,

    #[cfg(feature = "f32")]
    rotation: Handle<Image>,
    #[cfg(feature = "f32")]
    scale_opacity: Handle<Image>,
}

impl ExtractComponent for TextureBuffers {
    type Query = &'static Self;

    type Filter = ();
    type Out = Self;

    fn extract_component(texture_buffers: QueryItem<'_, Self::Query>) -> Option<Self::Out> {
        texture_buffers.clone().into()
    }
}


#[derive(Default)]
pub struct BufferTexturePlugin;

impl Plugin for BufferTexturePlugin {
    fn build(&self, app: &mut App) {
        app.register_type::<TextureBuffers>();
        app.add_plugins(ExtractComponentPlugin::<TextureBuffers>::default());

        app.add_systems(Update, queue_textures);

        let render_app = app.sub_app_mut(RenderApp);
        render_app.add_systems(
            Render,
            queue_gpu_texture_buffers.in_set(RenderSet::PrepareAssets),
        );
    }
}


#[derive(Component, Clone, Debug)]
pub struct GpuTextureBuffers {
    pub bind_group: BindGroup,
}

pub fn queue_gpu_texture_buffers(
    mut commands: Commands,
    pipeline: Res<GaussianCloudPipeline>,
    render_device: Res<RenderDevice>,
    gpu_images: Res<RenderAssets<Image>>,
    clouds: Query<(
        Entity,
        &TextureBuffers,
    )>,
) {
    for (entity, texture_buffers,) in clouds.iter() {
        #[cfg(feature = "f16")]
        let bind_group = render_device.create_bind_group(
            Some("texture_gaussian_cloud_bind_group"),
            &pipeline.gaussian_cloud_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.rotation_scale_opacity).unwrap().texture_view
                    ),
                },
            ],
        );

        #[cfg(feature = "f32")]
        let bind_group = render_device.create_bind_group(
            Some("texture_gaussian_cloud_bind_group"),
            &pipeline.gaussian_cloud_layout,
            &[
                BindGroupEntry {
                    binding: 0,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.position_visibility).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.spherical_harmonics).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.rotation).unwrap().texture_view
                    ),
                },
                BindGroupEntry {
                    binding: 3,
                    resource: BindingResource::TextureView(
                        &gpu_images.get(&texture_buffers.scale_opacity).unwrap().texture_view
                    ),
                },
            ],
        );

        commands.entity(entity).insert(GpuTextureBuffers { bind_group });
    }
}


// TODO: support asset change detection and reupload
fn queue_textures(
    mut commands: Commands,
    asset_server: Res<AssetServer>,
    gaussian_cloud_res: Res<Assets<GaussianCloud>>,
    mut images: ResMut<Assets<Image>>,
    clouds: Query<(
        Entity,
        &Handle<GaussianCloud>,
        Without<TextureBuffers>,
    )>,
) {
    for (entity, cloud_handle, _) in clouds.iter() {
        if Some(LoadState::Loading) == asset_server.get_load_state(cloud_handle) {
            continue;
        }

        // Skip clouds whose asset data is not available yet.
        let Some(cloud) = gaussian_cloud_res.get(cloud_handle) else {
            continue;
        };

        let square = cloud.len_sqrt_ceil() as u32;
        let extent_1d = Extent3d {
            width: square,
            height: square, // TODO: shrink height to save memory (consider fixed width)
            depth_or_array_layers: 1,
        };

        let mut position_visibility = Image::new(
            extent_1d,
            TextureDimension::D2,
            bytemuck::cast_slice(cloud.position_visibility.as_slice()).to_vec(),
            TextureFormat::Rgba32Float,
        );
        position_visibility.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
        let position_visibility = images.add(position_visibility);

        let texture_buffers: TextureBuffers;

        #[cfg(feature = "f16")]
        {
            let planar_spherical_harmonics: Vec<u32> = (0..SH_VEC4_PLANES)
                .flat_map(|plane_index| {
                    cloud.spherical_harmonic.iter()
                        .flat_map(move |sh| {
                            let start_index = plane_index * 4;
                            let end_index = std::cmp::min(start_index + 4, sh.coefficients.len());

                            let mut depthwise = sh.coefficients[start_index..end_index].to_vec();
                            depthwise.resize(4, 0);

                            depthwise
                        })
                })
                .collect();

            let mut spherical_harmonics = Image::new(
                Extent3d {
                    width: square,
                    height: square,
                    depth_or_array_layers: SH_VEC4_PLANES as u32,
                },
                TextureDimension::D2,
                bytemuck::cast_slice(planar_spherical_harmonics.as_slice()).to_vec(),
                TextureFormat::Rgba32Uint,
            );
            spherical_harmonics.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
            let spherical_harmonics = images.add(spherical_harmonics);

            let mut rotation_scale_opacity = Image::new(
                extent_1d,
                TextureDimension::D2,
                bytemuck::cast_slice(cloud.rotation_scale_opacity_packed128.as_slice()).to_vec(),
                TextureFormat::Rgba32Uint,
            );
            rotation_scale_opacity.texture_descriptor.usage = TextureUsages::COPY_DST | TextureUsages::TEXTURE_BINDING;
            let rotation_scale_opacity = images.add(rotation_scale_opacity);

            texture_buffers = TextureBuffers {
                position_visibility,
                spherical_harmonics,
                rotation_scale_opacity,
            };
        }

        #[cfg(feature = "f32")]
        {
            texture_buffers = TextureBuffers {
                position_visibility,
                spherical_harmonics: todo!(),
                rotation: todo!(),
                scale_opacity: todo!(),
            };
        }

        commands.entity(entity).insert(texture_buffers);
    }
}


pub fn get_sorted_bind_group_layout(
    render_device: &RenderDevice,
) -> BindGroupLayout {
    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_sorted_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
        ],
    })
}


#[cfg(feature = "f16")]
pub fn get_bind_group_layout(
    render_device: &RenderDevice,
    _read_only: bool
) -> BindGroupLayout {
    let sh_view_dimension = if SH_VEC4_PLANES == 1 {
        TextureViewDimension::D2
    } else {
        TextureViewDimension::D2Array
    };

    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_f16_gaussian_cloud_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Float {
                        filterable: false,
                    },
                    multisampled: false,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: sh_view_dimension,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::all(),
                ty: BindingType::Texture {
                    view_dimension: TextureViewDimension::D2,
                    sample_type: TextureSampleType::Uint,
                    multisampled: false,
                },
                count: None,
            },
        ],
    })
}


#[cfg(feature = "f32")]
pub fn get_bind_group_layout(
    render_device: &RenderDevice,
    read_only: bool
) -> BindGroupLayout {
    render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
        label: Some("texture_f32_gaussian_cloud_layout"),
        entries: &[
            BindGroupLayoutEntry {
                binding: 0,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<PositionVisibility>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 1,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<SphericalHarmonicCoefficients>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 2,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<Rotation>() as u64),
                },
                count: None,
            },
            BindGroupLayoutEntry {
                binding: 3,
                visibility: ShaderStages::all(),
                ty: BindingType::Buffer {
                    ty: BufferBindingType::Storage { read_only },
                    has_dynamic_offset: false,
                    min_binding_size: BufferSize::new(std::mem::size_of::<ScaleOpacity>() as u64),
                },
                count: None,
            },
        ],
    })
}
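
For the TODO on `height: square` in `queue_textures`, one option it hints at is a fixed texture width with only as many rows as the cloud needs. The sketch below shows that arithmetic in isolation. Both function names are hypothetical, the width of 1024 in the usage note is an arbitrary choice, and the shader-side index mapping would have to change to match.

```rust
/// Hypothetical helper (not part of the crate): (width, height) of a 2D
/// texture that stores `count` texels in rows of a fixed `width`.
pub fn fixed_width_extent(count: usize, width: u32) -> (u32, u32) {
    assert!(width > 0, "width must be non-zero");
    let width64 = u64::from(width);
    // Ceiling division; at least one row so the texture is never empty.
    let rows = ((count as u64 + width64 - 1) / width64).max(1);
    (width, rows as u32)
}

/// Texel coordinates of the gaussian at `index` in a row-major layout.
pub fn texel_coords(index: u32, width: u32) -> (u32, u32) {
    (index % width, index / width)
}
```

For example, 1,000,001 gaussians at a width of 1024 need 977 rows, leaving 447 unused texels in the last row. A fixed width only pays off when it is small relative to the cloud size, and the image data must still be padded to `width * height` texels before upload.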

separate shader defines for each pipeline

// TODO: separate shader defines for each pipeline

            ],
        });

        GaussianCloudPipeline {
            gaussian_cloud_layout,
            gaussian_uniform_layout,
            view_layout,
            shader: GAUSSIAN_SHADER_HANDLE,
            sorted_layout,
        }
    }
}

// TODO: allow setting shader defines via API
// TODO: separate shader defines for each pipeline
struct ShaderDefines {
    radix_bits_per_digit: u32,
    radix_digit_places: u32,

allow hot reloading of GaussianCloud handle through inspector UI

https://api.github.com/mosure/bevy_gaussian_splatting/blob/50685245da00f48ce0d49fe9b05fab65cd2fb149/src/lib.rs#L34

impl Plugin for GaussianSplattingPlugin {
    fn build(&self, app: &mut App) {
        // TODO: allow hot reloading of GaussianCloud handle through inspector UI
        app.add_asset::<GaussianCloud>();
        app.init_asset_loader::<GaussianCloudLoader>();

        app.register_asset_reflect::<GaussianCloud>();
        app.register_type::<GaussianCloudSettings>();
        app.register_type::<GaussianSplattingBundle>();

        app.add_plugins((

separate sort and render pipelines into separate files

// TODO: separate sort and render pipelines into separate files

};


// TODO: separate sort and render pipelines into separate files
const BINDINGS_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(675257236);
const GAUSSIAN_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(68294581);
const RADIX_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(6234673214);
const SPHERICAL_HARMONICS_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(834667312);
const TEMPORAL_SORT_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(1634543224);
const TRANSFORM_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(734523534);

pub mod node {
    pub const RADIX_SORT: &str = "radix_sort";

depth order validation over ndc cells

// TODO: depth order validation over ndc cells

use std::{
    process::exit,
    sync::{
        Arc,
        Mutex,
    },
};

use bevy::{
    prelude::*,
    core::FrameCount,
    core_pipeline::core_3d::{
        CORE_3D,
        Transparent3d,
    },
    render::{
        RenderApp,
        renderer::{
            RenderContext,
            RenderQueue,
        },
        render_asset::RenderAssets,
        render_graph::{
            Node,
            NodeRunError,
            RenderGraphApp,
            RenderGraphContext,
        },
        render_phase::RenderPhase,
        view::ExtractedView,
    },
};

use bevy_gaussian_splatting::{
    GaussianCloud,
    GaussianCloudSettings,
    GaussianSplattingBundle,
    random_gaussians,
    sort::{
        SortMode,
        SortedEntries,
    },
};

use _harness::{
    TestHarness,
    test_harness_app,
    TestState,
    TestStateArc,
};

mod _harness;


pub mod node {
    pub const RADIX_SORT_TEST: &str = "radix_sort_test";
}


fn main() {
    let mut app = test_harness_app(TestHarness {
        resolution: (512.0, 512.0),
    });

    app.add_systems(Startup, setup);

    if let Ok(render_app) = app.get_sub_app_mut(RenderApp) {
        render_app
            .add_render_graph_node::<RadixTestNode>(
                CORE_3D,
                node::RADIX_SORT_TEST,
            )
            .add_render_graph_edge(
                CORE_3D,
                node::RADIX_SORT_TEST,
                bevy::core_pipeline::core_3d::graph::node::END_MAIN_PASS,
            );
    }

    app.run();
}

fn setup(
    mut commands: Commands,
    mut gaussian_assets: ResMut<Assets<GaussianCloud>>,
) {
    let cloud = gaussian_assets.add(random_gaussians(10000));

    commands.spawn((
        GaussianSplattingBundle {
            cloud,
            settings: GaussianCloudSettings {
                sort_mode: SortMode::Radix,
                ..default()
            },
            ..default()
        },
        Name::new("gaussian_cloud"),
    ));

    commands.spawn((
        Camera3dBundle {
            transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)),
            ..default()
        },
    ));
}


pub struct RadixTestNode {
    gaussian_clouds: QueryState<(
        &'static Handle<GaussianCloud>,
        &'static Handle<SortedEntries>,
    )>,
    state: TestStateArc,
    views: QueryState<(
        &'static ExtractedView,
        &'static RenderPhase<Transparent3d>,
    )>,
    start_frame: u32,
}

impl FromWorld for RadixTestNode {
    fn from_world(world: &mut World) -> Self {
        Self {
            gaussian_clouds: world.query(),
            state: Arc::new(Mutex::new(TestState::default())),
            views: world.query(),
            start_frame: 0,
        }
    }
}


impl Node for RadixTestNode {
    fn update(
        &mut self,
        world: &mut World,
    ) {
        let mut state = self.state.lock().unwrap();
        if state.test_completed {
            exit(0);
        }

        if state.test_loaded && self.start_frame == 0 {
            self.start_frame = world.get_resource::<FrameCount>().unwrap().0;
        }

        let frame_count = world.get_resource::<FrameCount>().unwrap().0;
        const FRAME_LIMIT: u32 = 10;
        if state.test_loaded && frame_count >= self.start_frame + FRAME_LIMIT {
            state.test_completed = true;
        }

        self.gaussian_clouds.update_archetypes(world);
        self.views.update_archetypes(world);
    }

    fn run(
        &self,
        _graph: &mut RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), NodeRunError> {
        for (view, _phase,) in self.views.iter_manual(world) {
            let camera_position = view.transform.translation();

            for (
                cloud_handle,
                sorted_entries_handle,
            ) in self.gaussian_clouds.iter_manual(world) {
                let gaussian_cloud_res = world.get_resource::<RenderAssets<GaussianCloud>>().unwrap();
                let sorted_entries_res = world.get_resource::<RenderAssets<SortedEntries>>().unwrap();

                let mut state = self.state.lock().unwrap();
                if gaussian_cloud_res.get(cloud_handle).is_none() || sorted_entries_res.get(sorted_entries_handle).is_none() {
                    continue;
                } else if !state.test_loaded {
                    state.test_loaded = true;
                }

                let cloud = gaussian_cloud_res.get(cloud_handle).unwrap();
                let sorted_entries = sorted_entries_res.get(sorted_entries_handle).unwrap();
                let gaussians = cloud.debug_gpu.gaussians.clone();

                wgpu::util::DownloadBuffer::read_buffer(
                    render_context.render_device().wgpu_device(),
                    world.get_resource::<RenderQueue>().unwrap().0.as_ref(),
                    &sorted_entries.sorted_entry_buffer.slice(
                        0..sorted_entries.sorted_entry_buffer.size()
                    ),
                    move |buffer: Result<wgpu::util::DownloadBuffer, wgpu::BufferAsyncError>| {
                        let binding = buffer.unwrap();
                        let u32_muck = bytemuck::cast_slice::<u8, u32>(&*binding);

                        let mut radix_sorted_indices = Vec::new();
                        for i in (1..u32_muck.len()).step_by(2) {
                            radix_sorted_indices.push((i, u32_muck[i] as usize));
                        }

                        // TODO: depth order validation over ndc cells

                        radix_sorted_indices.iter()
                            .fold(0.0, |depth_acc, &(entry_idx, idx)| {
                                if idx == 0 || u32_muck[entry_idx - 1] == 0xffffffff || u32_muck[entry_idx - 1] == 0x0 {
                                    return depth_acc;
                                }

                                let position = gaussians[idx].position;
                                let position_vec3 = Vec3::new(position[0], position[1], position[2]);
                                let depth = (position_vec3 - camera_position).length();

                                let depth_is_non_decreasing = depth_acc <= depth;
                                if !depth_is_non_decreasing {
                                    println!(
                                        "radix keys: [..., {:#010x}, {:#010x}, {:#010x}, ...]",
                                        u32_muck[entry_idx - 1 - 2],
                                        u32_muck[entry_idx - 1],
                                        u32_muck[entry_idx - 1 + 2],
                                    );
                                }

                                assert!(depth_is_non_decreasing, "radix sort, non-decreasing check failed: {} > {}", depth_acc, depth);

                                depth_acc.max(depth)
                            });
                    }
                );
            }
        }

        Ok(())
    }
}
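
For reference, the non-decreasing depth check done by the fold above can be pulled out into a plain function that runs without a GPU. This is only a sketch under stated assumptions: `first_depth_order_violation` is a hypothetical helper, not part of bevy_gaussian_splatting, it takes positions as bare `[f32; 3]` arrays, and it does not skip padding keys or implement the NDC-cell partitioning named in the TODO.

```rust
/// Hypothetical helper (not part of the crate): returns the position within
/// `sorted_indices` of the first entry that is closer to the camera than the
/// entry before it, or `None` if depth never decreases along the sorted order.
///
/// Panics if any index in `sorted_indices` is out of range for `positions`.
pub fn first_depth_order_violation(
    positions: &[[f32; 3]],
    sorted_indices: &[usize],
    camera_position: [f32; 3],
) -> Option<usize> {
    // Euclidean distance from the camera, matching the metric used in the test.
    let depth = |idx: usize| -> f32 {
        let p = positions[idx];
        let d = [
            p[0] - camera_position[0],
            p[1] - camera_position[1],
            p[2] - camera_position[2],
        ];
        (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt()
    };

    // Compare each adjacent pair; a violation is a strict decrease in depth.
    sorted_indices
        .windows(2)
        .position(|pair| depth(pair[0]) > depth(pair[1]))
        .map(|i| i + 1)
}
```

A per-cell variant could presumably group indices by their NDC cell first and call the same function once per group, but how cells should be defined is left open by the TODO.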

distance to gaussian cloud centroid

let rangefinder = view.rangefinder3d();

.distance_translation(&mesh_instance.transforms.transform.translation),

https://api.github.com/mosure/bevy_gaussian_splatting/blob/51bdd5d4efa484d388a0fdbe1d9684c06746c418/src/render/mod.rs#L271

                let pipeline = pipelines.specialize(&pipeline_cache, &custom_pipeline, key);

                // // TODO: distance to gaussian cloud centroid
                // let rangefinder = view.rangefinder3d();

                transparent_phase.add(Transparent3d {
                    entity,
                    draw_function: draw_custom,
                    distance: 0.0,
                    // distance: rangefinder
                    //     .distance_translation(&mesh_instance.transforms.transform.translation),
                    pipeline,
                    batch_range: 0..1,
                    dynamic_offset: None,
                });
            }
        }
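
One way to fill in the `distance: 0.0` placeholder is to compute the cloud centroid once and measure from the camera to it. The sketch below is an assumption-laden illustration, not the crate's implementation: `centroid` and `distance_to_centroid` are hypothetical helpers, positions are bare `[f32; 3]` arrays already in world space, and the metric is plain Euclidean distance. Bevy's `rangefinder3d()` measures depth along the view axis instead, so the values would differ from what `distance_translation` returns.

```rust
/// Hypothetical helper: arithmetic mean of all gaussian positions,
/// or `None` for an empty cloud.
pub fn centroid(positions: &[[f32; 3]]) -> Option<[f32; 3]> {
    if positions.is_empty() {
        return None;
    }

    let mut sum = [0.0f32; 3];
    for p in positions {
        for axis in 0..3 {
            sum[axis] += p[axis];
        }
    }

    let n = positions.len() as f32;
    Some([sum[0] / n, sum[1] / n, sum[2] / n])
}

/// Hypothetical helper: Euclidean distance from the camera to the centroid.
/// Assumes `positions` and `camera_position` share one coordinate space.
pub fn distance_to_centroid(
    positions: &[[f32; 3]],
    camera_position: [f32; 3],
) -> Option<f32> {
    let c = centroid(positions)?;
    let d = [
        c[0] - camera_position[0],
        c[1] - camera_position[1],
        c[2] - camera_position[2],
    ];
    Some((d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt())
}
```

In the render code the centroid would likely be cached alongside the cloud asset and transformed by the entity's transform before being handed to the rangefinder, rather than recomputed each frame.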
