· 6 years ago · Oct 19, 2019, 09:38 AM
1/*
2 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
3 holder of all proprietary rights on this computer program.
4 You can only use this computer program if you have closed
5 a license agreement with MPG or you get the right to use the computer
6 program from someone who is authorized to grant you that right.
7 Any use of the computer program without a valid license is prohibited and
8 liable to prosecution.
9
10 Copyright©2019 Max-Planck-Gesellschaft zur Förderung
11 der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
12 for Intelligent Systems and the Max Planck Institute for Biological
13 Cybernetics. All rights reserved.
14
15 Contact: ps-license@tuebingen.mpg.de
16*/
17
18#include <ATen/ATen.h>
19
20#include <cuda.h>
21#include "device_launch_parameters.h"
22#include <cuda_runtime.h>
23
24#include <thrust/iterator/counting_iterator.h>
25#include <thrust/functional.h>
26#include <thrust/remove.h>
27#include <thrust/host_vector.h>
28#include <thrust/device_vector.h>
29#include <thrust/reduce.h>
30#include <thrust/sort.h>
31
32#include <vector>
33#include <iostream>
34#include <string>
35#include <type_traits>
36
37#include "double_vec_ops.h"
38#include "helper_math.h"
39
// Capacity of the per-thread stack used while traversing the Bounding Volume
// Hierarchy tree.
#if !defined(STACK_SIZE)
#define STACK_SIZE 64
#endif /* !defined(STACK_SIZE) */

// Upper bound for the number of possible collisions.
#if !defined(MAX_COLLISIONS)
#define MAX_COLLISIONS 16
#endif /* !defined(MAX_COLLISIONS) */

#if !defined(EPSILON)
#define EPSILON 1e-16
#endif /* !defined(EPSILON) */

// Number of threads per block used for the CUDA kernel launches in this file.
#if !defined(NUM_THREADS)
#define NUM_THREADS 128
#endif /* !defined(NUM_THREADS) */

// When set to 1 the BVH traversal skips subtrees whose rightmost leaf does
// not come after the query leaf, so that each candidate pair is visited once.
#if !defined(COLLISION_ORDERING)
#define COLLISION_ORDERING 1
#endif /* !defined(COLLISION_ORDERING) */

// When set to 1 the small device helpers are marked __forceinline__.
#if !defined(FORCE_INLINE)
#define FORCE_INLINE 1
#endif /* !defined(FORCE_INLINE) */

// When set to 1 cudaCheckError() synchronizes and checks for CUDA errors;
// otherwise it expands to nothing.
#if !defined(ERROR_CHECKING)
#define ERROR_CHECKING 1
#endif /* !defined(ERROR_CHECKING) */
70
// Macro for checking CUDA errors following a kernel launch or API call.
//
// With ERROR_CHECKING == 1 it waits for the device to finish and, if either
// the synchronization itself or the last recorded runtime call reports an
// error, prints the source location and the error string to stderr and
// terminates the process with a failure exit status. With any other value it
// expands to an empty statement.
//
// The body is wrapped in do { ... } while (0) so that `cudaCheckError();`
// behaves like a single statement (e.g. inside an unbraced if/else).
#if ERROR_CHECKING == 1
#define cudaCheckError()                                                      \
  do {                                                                        \
    cudaError_t cuda_check_err_ = cudaDeviceSynchronize();                    \
    if (cuda_check_err_ == cudaSuccess) {                                     \
      cuda_check_err_ = cudaGetLastError();                                   \
    }                                                                         \
    if (cuda_check_err_ != cudaSuccess) {                                     \
      fprintf(stderr, "Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__,       \
              cudaGetErrorString(cuda_check_err_));                           \
      exit(EXIT_FAILURE);                                                     \
    }                                                                         \
  } while (0)
#else
#define cudaCheckError()                                                      \
  do {                                                                        \
  } while (0)
#endif
86
// Integer type that holds the Morton code of a primitive.
using MortonCode = unsigned int;

// 3-component CUDA vector type matching the scalar type T: float3 when T is
// float, double3 for every other T.
template <typename T>
using vec3 =
    typename std::conditional<std::is_same<T, float>::value, float3,
                              double3>::type;

// 2-component CUDA vector type matching the scalar type T: float2 when T is
// float, double2 for every other T.
template <typename T>
using vec2 =
    typename std::conditional<std::is_same<T, float>::value, float2,
                              double2>::type;
96
// Stream insertion for the 3-component CUDA vector types. The components are
// written as "x, y, z".
//
// Only concrete overloads are provided. vec3<T> is an alias for
// std::conditional<...>::type, which is a non-deduced context, so a function
// template taking vec3<T> can never have T deduced from its argument; two
// such templates (one by const reference, one by value) would additionally be
// ambiguous with each other when T is spelled out explicitly.
std::ostream &operator<<(std::ostream &os, const vec3<float> &x) {
  os << x.x << ", " << x.y << ", " << x.z;
  return os;
}

std::ostream &operator<<(std::ostream &os, const vec3<double> &x) {
  os << x.x << ", " << x.y << ", " << x.z;
  return os;
}
119
// Component-wise minimum of two double3 vectors.
__host__ __device__ inline double3 fmin(const double3 &a, const double3 &b) {
  const double lo_x = fmin(a.x, b.x);
  const double lo_y = fmin(a.y, b.y);
  const double lo_z = fmin(a.z, b.z);
  return make_double3(lo_x, lo_y, lo_z);
}
123
// Component-wise maximum of two double3 vectors.
__host__ __device__ inline double3 fmax(const double3 &a, const double3 &b) {
  const double hi_x = fmax(a.x, b.x);
  const double hi_y = fmax(a.y, b.y);
  const double hi_z = fmax(a.z, b.z);
  return make_double3(hi_x, hi_y, hi_z);
}
127
// Unary functor for thrust algorithms: maps a pair of indices to 1 when both
// entries are non-negative and to 0 otherwise.
struct is_valid_cnt : public thrust::unary_function<long2, int> {
  __host__ __device__ int operator()(long2 vec) const {
    const bool first_ok = vec.x >= 0;
    const bool second_ok = vec.y >= 0;
    return (first_ok && second_ok) ? 1 : 0;
  }
};
134
// Sum of the absolute per-component differences between two vectors.
// NOTE(review): the result is returned as float even when T is double, so a
// double instantiation is narrowed on return - confirm this is intended.
template <typename T>
__host__ __device__ __forceinline__ float vec_abs_diff(const vec3<T> &vec1,
                                                       const vec3<T> &vec2) {
  const auto abs_dx = fabs(vec1.x - vec2.x);
  const auto abs_dy = fabs(vec1.y - vec2.y);
  const auto abs_dz = fabs(vec1.z - vec2.z);
  return abs_dx + abs_dy + abs_dz;
}
140
// Dot product of the difference of two vectors with itself.
// NOTE(review): the result is returned as float even when T is double, so a
// double instantiation is narrowed on return - confirm this is intended.
template <typename T>
__host__ __device__ __forceinline__ float vec_sq_diff(const vec3<T> &vec1,
                                                      const vec3<T> &vec2) {
  const auto lhs = vec1 - vec2;
  const auto rhs = vec1 - vec2;
  return dot(lhs, rhs);
}
146
// Axis-aligned bounding box described by its minimum and maximum corner.
template <typename T> struct AABB {
public:
  // Creates an "empty" box: the minimum corner is set to the largest finite
  // value (FLT_MAX for float, DBL_MAX otherwise) and the maximum corner to
  // its negation, so merging it with any other box yields that other box.
  __host__ __device__ AABB() {
    const auto largest = std::is_same<T, float>::value ? FLT_MAX : DBL_MAX;
    min_t.x = min_t.y = min_t.z = largest;
    max_t.x = max_t.y = max_t.z = -largest;
  }

  // Creates a box from its two corners.
  __host__ __device__ AABB(const vec3<T> &min_t, const vec3<T> &max_t)
      : min_t(min_t), max_t(max_t) {}

  __host__ __device__ ~AABB() {}

  // Creates a box from the six individual corner coordinates.
  __host__ __device__ AABB(T min_t_x, T min_t_y, T min_t_z, T max_t_x,
                           T max_t_y, T max_t_z) {
    min_t.x = min_t_x;
    max_t.x = max_t_x;
    min_t.y = min_t_y;
    max_t.y = max_t_y;
    min_t.z = min_t_z;
    max_t.z = max_t_z;
  }

  // Union: the smallest box that encloses both this box and bbox2.
  __host__ __device__ AABB<T> operator+(const AABB<T> &bbox2) const {
    const auto lo_x = min(this->min_t.x, bbox2.min_t.x);
    const auto lo_y = min(this->min_t.y, bbox2.min_t.y);
    const auto lo_z = min(this->min_t.z, bbox2.min_t.z);
    const auto hi_x = max(this->max_t.x, bbox2.max_t.x);
    const auto hi_y = max(this->max_t.y, bbox2.max_t.y);
    const auto hi_z = max(this->max_t.z, bbox2.max_t.z);
    return AABB<T>(lo_x, lo_y, lo_z, hi_x, hi_y, hi_z);
  }

  // Product of the per-axis extents of the overlap region of the two boxes.
  // The per-axis extent is (smaller max) - (larger min) and is not clamped,
  // so it is negative on an axis where the boxes do not overlap.
  __host__ __device__ T operator*(const AABB<T> &bbox2) const {
    const auto extent_x = min(this->max_t.x, bbox2.max_t.x) -
                          max(this->min_t.x, bbox2.min_t.x);
    const auto extent_y = min(this->max_t.y, bbox2.max_t.y) -
                          max(this->min_t.y, bbox2.min_t.y);
    const auto extent_z = min(this->max_t.z, bbox2.max_t.z) -
                          max(this->min_t.z, bbox2.min_t.z);
    return extent_x * extent_y * extent_z;
  }

  // Corner with the smallest coordinate on every axis.
  vec3<T> min_t;
  // Corner with the largest coordinate on every axis.
  vec3<T> max_t;
};
192
// Writes the minimum corner and then the maximum corner of the box, each on
// its own line (each followed by std::endl).
template <typename T>
std::ostream &operator<<(std::ostream &os, const AABB<T> &x) {
  os << x.min_t << std::endl << x.max_t << std::endl;
  return os;
}
199
// Binary functor handed to thrust::reduce to compute the bounding box of the
// whole scene: it returns the union of its two arguments.
template <typename T> struct MergeAABB {
public:
  __host__ __device__ MergeAABB() {}

  __host__ __device__ AABB<T> operator()(const AABB<T> &bbox1,
                                         const AABB<T> &bbox2) {
    AABB<T> merged = bbox1 + bbox2;
    return merged;
  }
};
212
// A triangle stored as its three vertices.
template <typename T> struct Triangle {
public:
  vec3<T> v0;
  vec3<T> v1;
  vec3<T> v2;

  // Creates a triangle from three vertices.
  __host__ __device__ Triangle(const vec3<T> &vertex0, const vec3<T> &vertex1,
                               const vec3<T> &vertex2)
      : v0(vertex0), v1(vertex1), v2(vertex2) {}

  // Returns the axis-aligned box spanned by the per-axis minimum and maximum
  // of the three vertices.
  __host__ __device__ AABB<T> ComputeBBox() {
    const T lo_x = min(v0.x, min(v1.x, v2.x));
    const T lo_y = min(v0.y, min(v1.y, v2.y));
    const T lo_z = min(v0.z, min(v1.z, v2.z));
    const T hi_x = max(v0.x, max(v1.x, v2.x));
    const T hi_y = max(v0.y, max(v1.y, v2.y));
    const T hi_z = max(v0.z, max(v1.z, v2.z));
    return AABB<T>(lo_x, lo_y, lo_z, hi_x, hi_y, hi_z);
  }
};

// Pointer to a Triangle with scalar type T.
template <typename T> using TrianglePtr = Triangle<T> *;
231
// Writes the three vertices of the triangle, each on its own line (each
// followed by std::endl).
template <typename T>
std::ostream &operator<<(std::ostream &os, const Triangle<T> &x) {
  os << x.v0 << std::endl << x.v1 << std::endl << x.v2 << std::endl;
  return os;
}
239
// Kernel: one thread per triangle. Thread i stores the bounding box of
// triangles[i] in bboxes[i]; threads with i >= num_triangles do nothing.
template <typename T>
__global__ void ComputeTriBoundingBoxes(Triangle<T> *triangles,
                                        int num_triangles, AABB<T> *bboxes) {
  const int tri_idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (tri_idx >= num_triangles) {
    return;
  }
  bboxes[tri_idx] = triangles[tri_idx].ComputeBBox();
}
248
// Projects the three vertices of a triangle onto sep_axis (via dot product)
// and returns the covered interval: .x is the smallest projection and .y the
// largest.
template <typename T>
__device__ inline vec2<T> isect_interval(const vec3<T> &sep_axis,
                                         const Triangle<T> &tri) {
  const T proj_v0 = dot(sep_axis, tri.v0);
  const T proj_v1 = dot(sep_axis, tri.v1);
  const T proj_v2 = dot(sep_axis, tri.v2);

  vec2<T> interval;
  interval.x = proj_v0;
  interval.y = proj_v0;

  interval.x = min(min(interval.x, proj_v1), proj_v2);
  interval.y = max(max(interval.y, proj_v1), proj_v2);

  return interval;
}
269
// Returns true when the projections of the two triangles onto sep_axis
// overlap (touching intervals count as overlapping), i.e. when sep_axis does
// not separate the triangles.
template <typename T>
__device__ inline bool TriangleTriangleOverlap(const Triangle<T> &tri1,
                                               const Triangle<T> &tri2,
                                               const vec3<T> &sep_axis) {
  // Segment covered by each triangle on the candidate axis.
  const vec2<T> span1 = isect_interval(sep_axis, tri1);
  const vec2<T> span2 = isect_interval(sep_axis, tri2);

  // The segments intersect only if each one starts before the other ends.
  if (!(span1.x <= span2.y)) {
    return false;
  }
  return span1.y >= span2.x;
}
284
// Separating-axis test for two triangles. Returns true when the projections
// of the triangles overlap on every one of the 17 candidate axes below, and
// false as soon as one axis separates them.
template <typename T>
__device__ bool TriangleTriangleIsectSepAxis(const Triangle<T> &tri1,
                                             const Triangle<T> &tri2) {
  // Edges and normal of the first triangle.
  const vec3<T> a_edge0 = tri1.v1 - tri1.v0;
  const vec3<T> a_edge1 = tri1.v2 - tri1.v0;
  const vec3<T> a_edge2 = tri1.v2 - tri1.v1;
  const vec3<T> a_normal = cross(a_edge1, a_edge2);

  // Edges and normal of the second triangle.
  const vec3<T> b_edge0 = tri2.v1 - tri2.v0;
  const vec3<T> b_edge1 = tri2.v2 - tri2.v0;
  const vec3<T> b_edge2 = tri2.v2 - tri2.v1;
  const vec3<T> b_normal = cross(b_edge1, b_edge2);

  // If the triangles are coplanar the first 11 axes all coincide with the
  // common normal, which is why the last six (normal x edge) are added.
  const vec3<T> candidate_axes[17] = {
      // The two face normals.
      a_normal,
      b_normal,
      // Every edge of the first triangle crossed with every edge of the
      // second one.
      cross(a_edge0, b_edge0),
      cross(a_edge0, b_edge1),
      cross(a_edge0, b_edge2),
      cross(a_edge1, b_edge0),
      cross(a_edge1, b_edge1),
      cross(a_edge1, b_edge2),
      cross(a_edge2, b_edge0),
      cross(a_edge2, b_edge1),
      cross(a_edge2, b_edge2),
      // Coplanar case: the normal of the first triangle crossed with the
      // edges of both triangles.
      cross(a_normal, a_edge0),
      cross(a_normal, a_edge1),
      cross(a_normal, a_edge2),
      cross(a_normal, b_edge0),
      cross(a_normal, b_edge1),
      cross(a_normal, b_edge2),
  };

#pragma unroll
  for (int axis = 0; axis < 17; ++axis) {
    if (!TriangleTriangleOverlap(tri1, tri2, candidate_axes[axis])) {
      return false;
    }
  }

  return true;
}
333
// Returns true when the two vertices have exactly equal x, y and z.
template <typename T>
__device__
#if FORCE_INLINE == 1
    __forceinline__
#endif
    bool
    sameVertexPosition(const vec3<T> &a, const vec3<T> &b) {
  return a.x == b.x && a.y == b.y && a.z == b.z;
}

// Returns true if the triangles share one or multiple vertices, i.e. if any
// vertex of tri1 compares exactly equal to any vertex of tri2.
template <typename T>
__device__
#if FORCE_INLINE == 1
    __forceinline__
#endif
    bool
    shareVertex(const Triangle<T> &tri1, const Triangle<T> &tri2) {
  if (sameVertexPosition<T>(tri1.v0, tri2.v0) ||
      sameVertexPosition<T>(tri1.v0, tri2.v1) ||
      sameVertexPosition<T>(tri1.v0, tri2.v2)) {
    return true;
  }
  if (sameVertexPosition<T>(tri1.v1, tri2.v0) ||
      sameVertexPosition<T>(tri1.v1, tri2.v1) ||
      sameVertexPosition<T>(tri1.v1, tri2.v2)) {
    return true;
  }
  return sameVertexPosition<T>(tri1.v2, tri2.v0) ||
         sameVertexPosition<T>(tri1.v2, tri2.v1) ||
         sameVertexPosition<T>(tri1.v2, tri2.v2);
}
353
// Kernel: one thread per candidate pair. Thread i reads the pair of triangle
// indices in collisions[i] and keeps it only if the two triangles pass the
// separating-axis intersection test and do not share a vertex; otherwise the
// entry is overwritten with (-1, -1).
//
// Parameters:
//   collisions          - candidate pairs, updated in place.
//   triangles           - triangle array indexed by the pair entries.
//   num_cand_collisions - number of entries in `collisions`.
//   num_triangles       - number of entries in `triangles`; assumed to be the
//                         length of the array passed in (TODO confirm with
//                         the caller).
//
// A pair whose indices are negative or not smaller than num_triangles is
// treated as "no collision" instead of being used to index `triangles`.
template <typename T>
__global__ void checkTriangleIntersections(long2 *collisions,
                                           Triangle<T> *triangles,
                                           int num_cand_collisions,
                                           int num_triangles) {
  const int idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx >= num_cand_collisions) {
    return;
  }

  const long2 candidate = collisions[idx];
  const bool indices_valid =
      candidate.x >= 0 && candidate.x < num_triangles && candidate.y >= 0 &&
      candidate.y < num_triangles;

  bool do_collide = false;
  if (indices_valid) {
    const Triangle<T> tri1 = triangles[candidate.x];
    const Triangle<T> tri2 = triangles[candidate.y];
    // The vertex comparison is the cheaper test, so it runs first; the
    // combined result does not depend on the order.
    do_collide = !shareVertex<T>(tri1, tri2) &&
                 TriangleTriangleIsectSepAxis<T>(tri1, tri2);
  }

  // A confirmed pair is left untouched; a rejected one is invalidated.
  if (!do_collide) {
    collisions[idx] = make_long2(-1, -1);
  }
}
376
// Node of the bounding volume hierarchy. Leaves and internal nodes use the
// same layout.
template <typename T> struct BVHNode {
public:
  // Bounding box of the node.
  AABB<T> bbox;

  // Children and parent of the node.
  BVHNode<T> *left;
  BVHNode<T> *right;
  BVHNode<T> *parent;
  // Stores the rightmost leaf node that can be reached from the current
  // node.
  BVHNode<T> *rightmost;

  // A node is a leaf when it has neither a left nor a right child.
  __host__ __device__ inline bool isLeaf() {
    return left == nullptr && right == nullptr;
  }

  // The index of the object contained in the node.
  int idx;
};

// Pointer to a BVHNode with scalar type T.
template <typename T> using BVHNodePtr = BVHNode<T> *;
395
// Returns true when the two boxes overlap on all three axes; boxes that only
// touch count as overlapping.
template <typename T>
__device__
#if FORCE_INLINE == 1
    __forceinline__
#endif
    bool
    checkOverlap(const AABB<T> &bbox1, const AABB<T> &bbox2) {
  const bool overlap_x =
      (bbox1.min_t.x <= bbox2.max_t.x) && (bbox1.max_t.x >= bbox2.min_t.x);
  const bool overlap_y =
      (bbox1.min_t.y <= bbox2.max_t.y) && (bbox1.max_t.y >= bbox2.min_t.y);
  const bool overlap_z =
      (bbox1.min_t.z <= bbox2.max_t.z) && (bbox1.max_t.z >= bbox2.min_t.z);
  return overlap_x && overlap_y && overlap_z;
}
407
// Reserves the next output slot through the shared counter and, when that
// slot lies inside the output buffer, stores the pair (smaller index first).
// A negative capacity disables the bound check. Returns 1 if the pair was
// stored and 0 if it was dropped because the buffer is full.
__device__ inline int storeCollisionPair(long2 *collisionIndices, int *counter,
                                         long long capacity, int first_idx,
                                         int second_idx) {
  const int slot = atomicAdd(counter, 1);
  if (capacity >= 0 && slot >= capacity) {
    return 0;
  }
  collisionIndices[slot] =
      make_long2(min(first_idx, second_idx), max(first_idx, second_idx));
  return 1;
}

// Traverses the hierarchy below `root` with an explicit stack and records
// every leaf whose bounding box overlaps queryAABB as a candidate pair.
//
// Parameters:
//   collisionIndices - output buffer for the candidate pairs.
//   root             - node at which the traversal starts; its children are
//                      dereferenced, so it must be an internal node.
//   queryAABB        - bounding box of the query object.
//   queryObjectIdx   - index stored in the pair for the query object.
//   leaf             - leaf node of the query object; used for the ordering
//                      test when COLLISION_ORDERING == 1.
//   max_collisions   - unused by the traversal itself.
//   counter          - global counter shared by all threads; every detected
//                      pair increments it, including pairs that are dropped.
//   capacity         - number of entries available in collisionIndices, or a
//                      negative value (the default) to skip the bound check.
//
// Returns the number of pairs stored by this call.
template <typename T>
__device__ int traverseBVH(long2 *collisionIndices, BVHNodePtr<T> root,
                           const AABB<T> &queryAABB, int queryObjectIdx,
                           BVHNodePtr<T> leaf, int max_collisions,
                           int *counter, long long capacity = -1) {
  int num_collisions = 0;
  // Allocate traversal stack from thread-local memory,
  // and push NULL to indicate that there are no postponed nodes.
  BVHNodePtr<T> stack[STACK_SIZE];
  BVHNodePtr<T> *stackPtr = stack;
  *stackPtr++ = nullptr; // push

  // Traverse nodes starting from the root.
  BVHNodePtr<T> node = root;
  do {
    // Check each child node for overlap.
    BVHNodePtr<T> childL = node->left;
    BVHNodePtr<T> childR = node->right;
    bool overlapL = checkOverlap<T>(queryAABB, childL->bbox);
    bool overlapR = checkOverlap<T>(queryAABB, childR->bbox);

#if COLLISION_ORDERING == 1
    /*
       If we do not impose any order, then all potential collisions will be
       reported twice (i.e. the query object with the i-th colliding object
       and the i-th colliding object with the query). In order to avoid
       this, we impose an ordering, saying that an object can collide with
       another only if it comes before it in the tree. For example, if we
       are checking for the object 10, there is no need to check the subtree
       that has the objects that are before it, since they will already have
       been checked.
    */
    if (leaf >= childL->rightmost) {
      overlapL = false;
    }
    if (leaf >= childR->rightmost) {
      overlapR = false;
    }
#endif

    const bool leafL = childL->isLeaf();
    const bool leafR = childR->isLeaf();

    // Query overlaps a leaf node => report collision.
    if (overlapL && leafL) {
      num_collisions += storeCollisionPair(collisionIndices, counter, capacity,
                                           queryObjectIdx, childL->idx);
    }
    if (overlapR && leafR) {
      num_collisions += storeCollisionPair(collisionIndices, counter, capacity,
                                           queryObjectIdx, childR->idx);
    }

    // Query overlaps an internal node => traverse.
    const bool traverseL = overlapL && !leafL;
    const bool traverseR = overlapR && !leafR;

    if (!traverseL && !traverseR) {
      node = *--stackPtr; // pop
    } else {
      node = traverseL ? childL : childR;
      if (traverseL && traverseR) {
        *stackPtr++ = childR; // push
      }
    }
  } while (node != nullptr);

  return num_collisions;
}

// Kernel: one thread per primitive. Thread i traverses the hierarchy with the
// bounding box of leaves[i] and appends the candidate pairs it finds to
// collisionIndices through the shared counter.
//
// collisionIndices is assumed to hold num_primitives * max_collisions entries
// (this is how bvh_cuda_forward in this file allocates it - confirm for any
// other caller); pairs beyond that capacity are dropped instead of being
// written past the end of the buffer. *counter still counts every detected
// pair, so it can end up larger than the capacity.
template <typename T>
__global__ void findPotentialCollisions(long2 *collisionIndices,
                                        BVHNodePtr<T> root,
                                        BVHNodePtr<T> leaves, int *triangle_ids,
                                        int num_primitives,
                                        int max_collisions, int *counter) {
  const int idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx >= num_primitives) {
    return;
  }

  BVHNodePtr<T> leaf = leaves + idx;
  const int triangle_id = triangle_ids[idx];
  // 64-bit product so that large inputs cannot overflow the capacity.
  const long long capacity =
      static_cast<long long>(num_primitives) * max_collisions;

  traverseBVH<T>(collisionIndices, root, leaf->bbox, triangle_id, leaf,
                 max_collisions, counter, capacity);
}
507
// Expands a 10-bit integer into 30 bits
// by inserting 2 zeros after each bit.
//
// Each step adds a shifted copy of the value to itself (the same as the
// multiplication by 2^k + 1, modulo 2^32) and masks the result.
__device__
#if FORCE_INLINE == 1
    __forceinline__
#endif
    MortonCode
    expandBits(MortonCode v) {
  // Shift 16
  v = (v + (v << 16)) & 0xFF0000FFu;
  // Shift 8
  v = (v + (v << 8)) & 0x0F00F00Fu;
  // Shift 4
  v = (v + (v << 4)) & 0xC30C30C3u;
  // Shift 2
  v = (v + (v << 2)) & 0x49249249u;
  return v;
}
526
// Calculates a 30-bit Morton code for the
// given 3D point located within the unit cube [0,1].
//
// Each coordinate is scaled by 1024, clamped to [0, 1023], truncated to an
// integer and bit-expanded; the three expanded values are then interleaved
// with x in the highest position of every bit triple.
template <typename T>
__device__
#if FORCE_INLINE == 1
    __forceinline__
#endif
    MortonCode
    morton3D(T x, T y, T z) {
  const T grid_x = min(max(x * 1024.0f, 0.0f), 1023.0f);
  const T grid_y = min(max(y * 1024.0f, 0.0f), 1023.0f);
  const T grid_z = min(max(z * 1024.0f, 0.0f), 1023.0f);

  const MortonCode bits_x = expandBits((MortonCode)grid_x);
  const MortonCode bits_y = expandBits((MortonCode)grid_y);
  const MortonCode bits_z = expandBits((MortonCode)grid_z);

  return (bits_x << 2) + (bits_y << 1) + bits_z;
}
544
// Kernel: one thread per triangle. Thread i computes the centroid of
// triangles[i], rescales it per axis by the extent of the scene bounding box
// and stores the Morton code of the rescaled point in morton_codes[i].
template <typename T>
__global__ void ComputeMortonCodes(Triangle<T> *triangles, int num_triangles,
                                   AABB<T> *scene_bb,
                                   MortonCode *morton_codes) {
  const int tri_idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (tri_idx >= num_triangles) {
    return;
  }

  // Fetch the current triangle and average its vertices.
  const Triangle<T> tri = triangles[tri_idx];
  const vec3<T> centroid = (tri.v0 + tri.v1 + tri.v2) / (T)3.0;

  // Offset from the scene minimum divided by the scene extent, per axis.
  const T unit_x = (centroid.x - scene_bb->min_t.x) /
                   (scene_bb->max_t.x - scene_bb->min_t.x);
  const T unit_y = (centroid.y - scene_bb->min_t.y) /
                   (scene_bb->max_t.y - scene_bb->min_t.y);
  const T unit_z = (centroid.z - scene_bb->min_t.z) /
                   (scene_bb->max_t.z - scene_bb->min_t.z);

  morton_codes[tri_idx] = morton3D<T>(unit_x, unit_y, unit_z);
}
566
// Length of the common bit prefix of the Morton codes at positions i and j.
// Returns -1 when either position lies outside [0, num_triangles). When the
// two codes are equal, the number of leading bits shared by the two triangle
// ids is added so that duplicate codes are still told apart.
__device__
#if FORCE_INLINE == 1
    __forceinline__
#endif
    int
    LongestCommonPrefix(int i, int j, MortonCode *morton_codes,
                        int num_triangles, int *triangle_ids) {
  // This function will be called for i - 1, i, i + 1, so we might go beyond
  // the array limits.
  const bool i_in_range = (i >= 0) && (i < num_triangles);
  const bool j_in_range = (j >= 0) && (j < num_triangles);
  if (!(i_in_range && j_in_range)) {
    return -1;
  }

  const MortonCode code_i = morton_codes[i];
  const MortonCode code_j = morton_codes[j];
  const int code_prefix = __clz(code_i ^ code_j);

  // Keys are different: the prefix of the codes alone is the answer.
  if (code_i != code_j) {
    return code_prefix;
  }

  // Duplicate key: extend the prefix with the leading zeros of the xor of the
  // two triangle ids.
  return code_prefix + __clz(triangle_ids[i] ^ triangle_ids[j]);
}
592
// Kernel: one thread per internal node (num_triangles - 1 of them). Thread
// idx determines the range of sorted keys covered by internal node idx, finds
// the position where that range splits, and links the node to its two
// children (leaf or internal) while setting the children's parent pointer.
//
// Parameters:
//   morton_codes   - Morton codes; the prefix searches below assume they are
//                    sorted (buildBVH in this file sorts them beforehand).
//   num_triangles  - number of leaves.
//   triangle_ids   - triangle id for every sorted position, used to break
//                    ties between equal codes.
//   internal_nodes - array of num_triangles - 1 internal nodes (output).
//   leaf_nodes     - array of num_triangles leaf nodes (output: parent only).
template <typename T>
__global__ void BuildRadixTree(MortonCode *morton_codes, int num_triangles,
                               int *triangle_ids, BVHNodePtr<T> internal_nodes,
                               BVHNodePtr<T> leaf_nodes) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= num_triangles - 1)
    return;

  int delta_next = LongestCommonPrefix(idx, idx + 1, morton_codes,
                                       num_triangles, triangle_ids);
  int delta_last = LongestCommonPrefix(idx, idx - 1, morton_codes,
                                       num_triangles, triangle_ids);
  // Find the direction of the range
  int direction = delta_next - delta_last >= 0 ? 1 : -1;

  int delta_min = LongestCommonPrefix(idx, idx - direction, morton_codes,
                                      num_triangles, triangle_ids);

  // Double the step until the key at that distance no longer shares more
  // than delta_min bits: this gives an upper bound for the range length.
  int lmax = 2;
  while (LongestCommonPrefix(idx, idx + lmax * direction, morton_codes,
                             num_triangles, triangle_ids) > delta_min) {
    lmax *= 2;
  }

  // Use binary search to find the other end, with steps lmax/2, lmax/4, ...,
  // 1 (lmax is a power of two).
  int l = 0;
  for (int t = lmax / 2; t >= 1; t /= 2) {
    if (LongestCommonPrefix(idx, idx + (l + t) * direction, morton_codes,
                            num_triangles, triangle_ids) > delta_min) {
      l = l + t;
    }
  }
  int j = idx + l * direction;

  // Find the length of the longest common prefix for the current node
  int node_delta =
      LongestCommonPrefix(idx, j, morton_codes, num_triangles, triangle_ids);

  // Search for the split position using binary search with steps
  // ceil(l/2), ceil(l/4), ..., 1. Halving the previous step with rounding up
  // produces exactly that sequence and stops after the step of size 1, so the
  // loop terminates without relying on integer overflow.
  int s = 0;
  int t = l;
  do {
    t = (t + 1) / 2;
    if (LongestCommonPrefix(idx, idx + (s + t) * direction, morton_codes,
                            num_triangles, triangle_ids) > node_delta) {
      s = s + t;
    }
  } while (t > 1);

  // gamma in the Karras paper
  int split = idx + s * direction + min(direction, 0);

  // Assign the parent and the left, right children for the current node.
  // A child is a leaf when the split coincides with the matching end of the
  // range, otherwise it is the internal node with the same index.
  BVHNodePtr<T> curr_node = internal_nodes + idx;
  if (min(idx, j) == split) {
    curr_node->left = leaf_nodes + split;
    (leaf_nodes + split)->parent = curr_node;
  } else {
    curr_node->left = internal_nodes + split;
    (internal_nodes + split)->parent = curr_node;
  }
  if (max(idx, j) == split + 1) {
    curr_node->right = leaf_nodes + split + 1;
    (leaf_nodes + split + 1)->parent = curr_node;
  } else {
    curr_node->right = internal_nodes + split + 1;
    (internal_nodes + split + 1)->parent = curr_node;
  }
}
663
// Kernel: one thread per leaf. Thread idx fills in leaf idx (object index,
// bounding box, rightmost pointer) and then walks towards the root. Every
// internal node has a visit counter: the first thread to arrive stops there,
// the second one computes the node's bounding box and rightmost leaf from its
// two children and continues upwards.
//
// Parameters:
//   internal_nodes  - internal nodes; internal_nodes[0] is treated as root.
//   leaf_nodes      - leaf nodes, with parent pointers already assigned.
//   num_triangles   - number of leaves.
//   triangles       - triangle array, indexed through triangle_ids.
//   triangle_ids    - triangle id for every leaf position.
//   atomic_counters - one counter per internal node; must be zero on entry.
template <typename T>
__global__ void CreateHierarchy(BVHNodePtr<T> internal_nodes,
                                BVHNodePtr<T> leaf_nodes, int num_triangles,
                                Triangle<T> *triangles, int *triangle_ids,
                                int *atomic_counters) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= num_triangles)
    return;

  BVHNodePtr<T> leaf = leaf_nodes + idx;
  // Assign the index to the primitive
  leaf->idx = triangle_ids[idx];

  Triangle<T> tri = triangles[triangle_ids[idx]];
  // Assign the bounding box of the triangle to the leaves
  leaf->bbox = tri.ComputeBBox();
  leaf->rightmost = leaf;

  BVHNodePtr<T> curr_node = leaf->parent;
  // A leaf without a parent has nothing to propagate. This is expected only
  // when there is a single leaf and therefore no internal node; it assumes
  // that unassigned parent pointers are null (TODO confirm for callers that
  // do not zero-initialize the node arrays).
  if (curr_node == nullptr)
    return;

  while (true) {
    // Make the node data written above by this thread visible device-wide
    // before announcing the visit, so that the thread that arrives second
    // reads the finished child data.
    __threadfence();

    // Position of the node in the flat array of internal nodes.
    int current_idx = curr_node - internal_nodes;
    // atomicAdd returns the old value at the specified address. Thus the
    // first thread to reach this point will immediately return.
    int curr_counter = atomicAdd(atomic_counters + current_idx, 1);
    if (curr_counter == 0)
      break;

    // Keep the child reads below ordered after the counter update.
    __threadfence();

    // Calculate the bounding box of the current node as the union of the
    // bounding boxes of its children.
    AABB<T> left_bb = curr_node->left->bbox;
    AABB<T> right_bb = curr_node->right->bbox;
    curr_node->bbox = left_bb + right_bb;
    // Store a pointer to the right most node that can be reached from this
    // internal node.
    curr_node->rightmost =
        curr_node->left->rightmost > curr_node->right->rightmost
            ? curr_node->left->rightmost
            : curr_node->right->rightmost;

    // If we have reached the root break
    if (curr_node == internal_nodes)
      break;

    // Proceed to the parent of the node
    curr_node = curr_node->parent;
  }

  return;
}
719
// Builds the bounding volume hierarchy for one mesh on the GPU.
//
// Steps: per-triangle bounding boxes -> scene bounding box (reduction) ->
// Morton code of every triangle centroid -> sort the triangle ids by Morton
// code -> build the radix tree -> fit the bounding boxes bottom-up.
//
// Parameters:
//   internal_nodes - device array with num_triangles - 1 nodes (output).
//   leaf_nodes     - device array with num_triangles nodes (output).
//   triangles      - device array with num_triangles triangles.
//   triangle_ids   - device vector that receives the triangle ids in sorted
//                    order; presumably sized num_triangles by the caller
//                    (TODO confirm).
//   num_triangles  - number of triangles.
//   batch_size     - unused.
//
// Nothing is built when there are fewer than two triangles, because the tree
// then has no internal node. An error raised by the sort is reported on
// stderr and rethrown: a tree built from unsorted codes would be invalid.
template <typename T>
void buildBVH(BVHNodePtr<T> internal_nodes, BVHNodePtr<T> leaf_nodes,
              Triangle<T>* __restrict__ triangles,
              thrust::device_vector<int> *triangle_ids, int num_triangles,
              int batch_size) {
  if (num_triangles < 2) {
    return;
  }

  const int blockSize = NUM_THREADS;
  const int gridSize = (num_triangles + blockSize - 1) / blockSize;

  // Compute the bounding box for all the triangles
  thrust::device_vector<AABB<T>> bounding_boxes(num_triangles);
  ComputeTriBoundingBoxes<T><<<gridSize, blockSize>>>(
      triangles, num_triangles, bounding_boxes.data().get());
  cudaCheckError();

  // Compute the union of all the bounding boxes
  AABB<T> host_scene_bb = thrust::reduce(
      bounding_boxes.begin(), bounding_boxes.end(), AABB<T>(), MergeAABB<T>());
  cudaCheckError();

  // TODO: Custom reduction ?
  // Copy the bounding box back to the GPU. The one-element device vector owns
  // the allocation, so it is released on every exit path.
  thrust::device_vector<AABB<T>> scene_bb(1, host_scene_bb);

  // Compute the morton codes for the centroids of all the primitives
  thrust::device_vector<MortonCode> morton_codes(num_triangles);
  ComputeMortonCodes<T><<<gridSize, blockSize>>>(
      triangles, num_triangles, scene_bb.data().get(),
      morton_codes.data().get());
  cudaCheckError();

  // Construct an array of triangle ids.
  thrust::sequence(triangle_ids->begin(), triangle_ids->end());

  // Sort the triangles according to the morton code
  try {
    thrust::sort_by_key(morton_codes.begin(), morton_codes.end(),
                        triangle_ids->begin());
  } catch (const thrust::system_error &e) {
    std::cerr << "Error inside sort: " << e.what() << std::endl;
    throw;
  }

  // Construct the radix tree using the sorted morton code sequence
  BuildRadixTree<T><<<gridSize, blockSize>>>(
      morton_codes.data().get(), num_triangles, triangle_ids->data().get(),
      internal_nodes, leaf_nodes);
  cudaCheckError();

  // Create an array that contains the atomic counters for each node in the
  // tree (value-initialized to zero by the device vector).
  thrust::device_vector<int> counters(num_triangles);

  // Build the Bounding Volume Hierarchy in parallel from the leaves to the
  // root
  CreateHierarchy<T><<<gridSize, blockSize>>>(
      internal_nodes, leaf_nodes, num_triangles, triangles,
      triangle_ids->data().get(), counters.data().get());
  cudaCheckError();

  return;
}
826
// Host entry point: for every mesh in the batch, builds the BVH and collects
// the pairs of triangles whose bounding boxes overlap.
//
// Parameters:
//   triangles            - tensor whose dimension 0 is the batch size and
//                          dimension 1 the number of triangles; its data is
//                          reinterpreted as an array of Triangle<scalar_t>,
//                          so it is presumably contiguous device memory of
//                          shape (batch, triangles, 3, 3) - TODO confirm.
//   collision_tensor_ptr - not accessed by the code below.
//                          NOTE(review): the candidate pairs end up in a
//                          local device vector that is discarded at the end
//                          of each iteration; nothing is written to this
//                          tensor here - confirm whether a step is missing.
//   max_collisions       - the candidate buffer holds
//                          num_triangles * max_collisions pairs.
//
// Returns without doing any work when the batch is empty, when there are
// fewer than two triangles (no pair can exist and the tree would have no
// internal node), or when max_collisions is not positive.
void bvh_cuda_forward(at::Tensor triangles, at::Tensor *collision_tensor_ptr,
                      int max_collisions = 16) {
  const auto batch_size = triangles.size(0);
  const auto num_triangles = triangles.size(1);

  if (batch_size < 1 || num_triangles < 2 || max_collisions < 1) {
    return;
  }

  thrust::device_vector<int> triangle_ids(num_triangles);

  int blockSize = NUM_THREADS;
  int gridSize = (num_triangles + blockSize - 1) / blockSize;

  thrust::device_vector<long2> collisionIndices(num_triangles * max_collisions);

  // One collision counter per batch element, all starting at zero.
  thrust::device_vector<int> collision_idx_cnt(batch_size);
  thrust::fill(collision_idx_cnt.begin(), collision_idx_cnt.end(), 0);

  // Construct the bvh tree
  AT_DISPATCH_FLOATING_TYPES(triangles.type(), "bvh_tree_building", ([&] {
    // The node arrays are allocated once and reused for every batch element.
    thrust::device_vector<BVHNode<scalar_t>> leaf_nodes(num_triangles);
    thrust::device_vector<BVHNode<scalar_t>> internal_nodes(num_triangles - 1);
    auto triangle_float_ptr = triangles.data<scalar_t>();

    for (int bidx = 0; bidx < batch_size; ++bidx) {
      // Triangles of the current batch element.
      Triangle<scalar_t> *triangles_ptr =
          (TrianglePtr<scalar_t>)triangle_float_ptr + num_triangles * bidx;

      // Mark every slot of the candidate buffer as empty.
      thrust::fill(collisionIndices.begin(), collisionIndices.end(),
                   make_long2(-1, -1));

      buildBVH<scalar_t>(internal_nodes.data().get(), leaf_nodes.data().get(),
                         triangles_ptr, &triangle_ids, num_triangles,
                         batch_size);

      findPotentialCollisions<scalar_t><<<gridSize, blockSize>>>(
          collisionIndices.data().get(),
          internal_nodes.data().get(),
          leaf_nodes.data().get(), triangle_ids.data().get(), num_triangles,
          max_collisions, &collision_idx_cnt.data().get()[bidx]);
      cudaDeviceSynchronize();

      cudaCheckError();

      // Calculate the number of potential collisions
      int num_cand_collisions = thrust::reduce(
          thrust::make_transform_iterator(collisionIndices.begin(),
                                          is_valid_cnt()),
          thrust::make_transform_iterator(collisionIndices.end(),
                                          is_valid_cnt()));
      // Keep only the pairs of ids where a bounding box to bounding box
      // collision was detected.
      thrust::device_vector<long2> collisions(num_cand_collisions,
                                              make_long2(-1, -1));
      thrust::copy_if(collisionIndices.begin(), collisionIndices.end(),
                      collisions.begin(), is_valid_cnt());

      cudaCheckError();
    }
  }));
}