use std::{cmp, sync::Arc}; use rand::seq::SliceRandom; use crate::hittable::{hittable_list::HittableList, HitRecord, Hittable, AABB}; use crate::ray::Ray; pub struct BVHNode { left: Arc, right: Arc, aabb: AABB, } #[derive(Clone, Copy)] enum Axis { X, Y, Z, } impl BVHNode { pub fn new(hittable_list: &HittableList, time_start: f64, time_end: f64) -> BVHNode { Self::from_objects( &hittable_list.objects, 0, hittable_list.objects.len(), time_start, time_end, ) } fn from_objects( src_objects: &Vec>, start: usize, end: usize, time_start: f64, time_end: f64, ) -> BVHNode { let mut objects = src_objects.clone(); let comparator = [ |a: &Arc, b: &Arc| { Self::box_compare(a.clone(), b.clone(), Axis::X) }, |a: &Arc, b: &Arc| { Self::box_compare(a.clone(), b.clone(), Axis::Y) }, |a: &Arc, b: &Arc| { Self::box_compare(a.clone(), b.clone(), Axis::Z) }, ] .choose(&mut rand::thread_rng()) .unwrap(); let object_span = end - start; let (left, right) = match object_span { 1 => ( objects.get(start).unwrap().clone(), objects.get(start).unwrap().clone(), ), 2 => match comparator(objects.get(start).unwrap(), objects.get(start + 1).unwrap()) { cmp::Ordering::Less => ( objects.get(start).unwrap().clone(), objects.get(start + 1).unwrap().clone(), ), _ => ( objects.get(start + 1).unwrap().clone(), objects.get(start).unwrap().clone(), ), }, _ => { objects[start..end].sort_by(comparator); let mid = start + object_span / 2; ( Arc::new(BVHNode::from_objects( &objects, start, mid, time_start, time_end, )) as Arc, Arc::new(BVHNode::from_objects( &objects, mid, end, time_start, time_end, )) as Arc, ) } }; let box_left = left .bounding_box(time_start, time_end) .expect("No bounding box in bvh_node constructor!"); let box_right = right .bounding_box(time_start, time_end) .expect("No bounding box in bvh_node constructor!"); BVHNode { left, right, aabb: box_left.surrounding_box(&box_right), } } fn box_compare(a: Arc, b: Arc, axis: Axis) -> cmp::Ordering { let box_a = a .bounding_box(0.0, 0.0) .expect("No bounding box in bvh_node constructor!"); let box_b = b .bounding_box(0.0, 0.0) .expect("No bounding box in bvh_node constructor!"); box_a .minimum .get(axis as usize) .unwrap() .total_cmp(box_b.minimum.get(axis as usize).unwrap()) } } impl Hittable for BVHNode { fn hit(&self, ray: &Ray, t_min: f64, t_max: f64) -> Option { if !self.aabb.hit(ray, t_min, t_max) { return None; } let hit_left = self.left.hit(ray, t_min, t_max); let hit_right_threshold = if let Some(hit_record_left) = &hit_left { hit_record_left.t } else { t_max }; let hit_right = self.right.hit(ray, t_min, hit_right_threshold); if let Some(_) = &hit_right { hit_right } else { hit_left } } fn bounding_box(&self, _: f64, _: f64) -> Option { Some(self.aabb.clone()) } }