Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ impl AggregateHashTable {
arena: Arc<Bump>,
need_init_entry: bool,
) -> Self {
debug_assert!(capacity.is_power_of_two());
let entries = if need_init_entry {
vec![Entry::default(); capacity]
} else {
Expand All @@ -110,6 +111,7 @@ impl AggregateHashTable {
entries,
count: 0,
capacity,
capacity_mask: capacity - 1,
},
config,
}
Expand Down
96 changes: 74 additions & 22 deletions src/query/expression/src/aggregate/hash_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,56 @@ pub(super) struct HashIndex {
pub entries: Vec<Entry>,
pub count: usize,
pub capacity: usize,
pub capacity_mask: usize,
}

/// Number of high-order hash bits used to derive the probing stride.
const INCREMENT_BITS: usize = 5;

/// Compute the probing stride for `hash`.
///
/// The stride is read from the top `INCREMENT_BITS` bits of the hash and its lowest bit is
/// forced to 1, so the result is always an odd value in `[1, 2^INCREMENT_BITS)`. An odd
/// stride is coprime with any power-of-two capacity, which lets a probe walk reach every slot.
#[inline(always)]
fn step(hash: u64) -> usize {
    let high_bits = hash >> (64 - INCREMENT_BITS);
    1 | (high_bits as usize)
}

/// Advance a probe position by the hash-derived stride, wrapping around the table.
///
/// `mask` is expected to be `capacity - 1` for a power-of-two capacity, so the bitwise AND
/// performs the wrap-around without a branch.
#[inline(always)]
fn next_slot(slot: usize, hash: u64, mask: usize) -> usize {
    let advanced = slot + step(hash);
    advanced & mask
}

/// Map a hash to its starting slot by keeping only the bits selected by `capacity_mask`.
#[inline(always)]
fn init_slot(hash: u64, capacity_mask: usize) -> usize {
    let truncated = hash as usize;
    truncated & capacity_mask
}

impl HashIndex {
/// Build an empty index holding `capacity` default-initialized entries.
///
/// `capacity` must be a power of two (asserted in debug builds); the stored
/// `capacity_mask` is `capacity - 1`.
pub fn with_capacity(capacity: usize) -> Self {
    debug_assert!(capacity.is_power_of_two());
    Self {
        // Listed first so the mask is computed before the entries are allocated.
        capacity_mask: capacity - 1,
        entries: vec![Entry::default(); capacity],
        count: 0,
        capacity,
    }
}

/// Starting slot for `hash`: the hash bits selected by `self.capacity - 1`.
fn init_slot(&self, hash: u64) -> usize {
    let mask = self.capacity - 1;
    (hash as usize) & mask
}

fn find_or_insert(&mut self, mut slot: usize, salt: u16) -> (usize, bool) {
fn find_or_insert(&mut self, mut slot: usize, hash: u64) -> (usize, bool) {
let salt = Entry::hash_to_salt(hash);
let entries = self.entries.as_mut_slice();
loop {
let entry = &mut entries[slot];
debug_assert!(entries.get(slot).is_some());
// SAFETY: slot is always in range
let entry = unsafe { entries.get_unchecked_mut(slot) };
if entry.is_occupied() {
if entry.get_salt() == salt {
return (slot, false);
} else {
slot += 1;
if slot >= self.capacity {
slot = 0;
}
slot = next_slot(slot, hash, self.capacity_mask);
continue;
}
} else {
Expand All @@ -59,13 +82,10 @@ impl HashIndex {
}

/// Probe from the initial slot of `hash` and return the index of the first entry that is
/// not occupied.
// NOTE(review): this body shows both the removed and the added lines of a diff hunk
// (two `let mut slot` bindings, and `slot += 1` with manual wrap next to `next_slot`).
// Confirm which lines belong to the final version before relying on the exact probe sequence.
pub fn probe_slot(&mut self, hash: u64) -> usize {
let mut slot = self.init_slot(hash);
let entries = self.entries.as_mut_slice();
let mut slot = init_slot(hash, self.capacity_mask);
// Keep moving until an unoccupied entry is reached.
while entries[slot].is_occupied() {
slot += 1;
if slot >= self.capacity {
slot = 0;
}
slot = next_slot(slot, hash, self.capacity_mask);
}
slot as _
}
Expand Down Expand Up @@ -159,8 +179,9 @@ impl HashIndex {
slots.extend(
state.group_hashes[..row_count]
.iter()
.map(|hash| self.init_slot(*hash)),
.map(|hash| init_slot(*hash, self.capacity_mask)),
);
let capacity_mask = self.capacity_mask;

let mut new_group_count = 0;
let mut remaining_entries = row_count;
Expand All @@ -176,7 +197,7 @@ impl HashIndex {
let hash = state.group_hashes[row];

let is_new;
(*slot, is_new) = self.find_or_insert(*slot, Entry::hash_to_salt(hash));
(*slot, is_new) = self.find_or_insert(*slot, hash);

if is_new {
state.empty_vector[new_entry_count] = row;
Expand Down Expand Up @@ -217,13 +238,11 @@ impl HashIndex {
no_match_count = adapter.compare(state, need_compare_count, no_match_count);
}

// 5. Linear probing, just increase iter_times
// 5. Linear probing with hash-derived step
for row in state.no_match_vector[..no_match_count].iter().copied() {
let slot = &mut slots[row];
*slot += 1;
if *slot >= self.capacity {
*slot = 0;
}
let hash = state.group_hashes[row];
*slot = next_slot(*slot, hash, capacity_mask);
}
remaining_entries = no_match_count;
}
Expand Down Expand Up @@ -262,6 +281,7 @@ impl<'a> TableAdapter for AdapterImpl<'a> {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::collections::HashSet;

use super::*;
use crate::ProbeState;
Expand Down Expand Up @@ -405,6 +425,38 @@ mod tests {
}
}

#[test]
fn test_probe_walk_covers_full_capacity() {
    // This test makes sure that a probe walk always covers every slot in the table.
    //
    // An odd step is coprime with every power-of-two capacity, so a walk of `capacity`
    // steps must visit each slot exactly once and end where it started. The capacities
    // below include values smaller than, equal to and larger than the step range
    // (2^INCREMENT_BITS).
    for capacity in [1usize, 2, 4, 8, 16, 32, 64, 128] {
        let capacity_mask = capacity - 1;

        for high_bits in 0u64..(1 << INCREMENT_BITS) {
            // The high bits select the step; reusing them as low bits varies the start slot.
            let hash = (high_bits << (64 - INCREMENT_BITS)) | high_bits;
            let start = init_slot(hash, capacity_mask);
            let mut slot = start;
            let mut visited = HashSet::with_capacity(capacity);

            for _ in 0..capacity {
                assert!(
                    visited.insert(slot),
                    "hash {hash:#x} revisited slot {slot} before covering the table"
                );
                slot = next_slot(slot, hash, capacity_mask);
            }

            assert_eq!(
                capacity,
                visited.len(),
                "hash {hash:#x} failed to cover every slot for capacity {capacity}"
            );
            assert_eq!(
                start, slot,
                "hash {hash:#x} walk did not return to its start after {capacity} steps"
            );
        }
    }
}

#[test]
fn test_hash_index() {
TestCase {
Expand Down
Loading