
Implementing Decision Trees in Rust

Learn how to implement a decision tree algorithm in Rust from scratch. This guide covers building the data structures, finding the best split, and testing the decision tree. Perfect for those looking to understand decision trees and Rust programming.

Published: Aug 06, 2024


Decision trees are a popular machine learning algorithm used for both classification and regression tasks. In this blog, we will implement a simple decision tree algorithm in Rust.

Understanding Decision Trees

A decision tree is a flowchart-like tree structure where:

- Each internal node tests a feature against a threshold value.
- Each branch represents the outcome of that test.
- Each leaf node holds a class label (or, for regression, a predicted value).

The goal is to split the data into subsets based on the feature that provides the best separation, often using metrics like Gini impurity or information gain.
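As a quick worked example, consider a node containing four samples with labels [0, 0, 1, 1]. Each class occurs with probability 0.5, so the Gini impurity is 0.5 * (1 - 0.5) + 0.5 * (1 - 0.5) = 0.5, the worst case for two classes. A pure node, where every sample shares the same label, has impurity 0, which is what a good split moves each subset towards.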

Setting Up the Rust Project

First, ensure you have Rust installed. You can check by running rustc --version in your terminal. Create a new Rust project using Cargo:

cargo new decision_tree
cd decision_tree

This creates a new directory called decision_tree with a basic Rust project setup.
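After running these commands you should see a layout roughly like this (Cargo may also initialize a Git repository if Git is available):

decision_tree/
├── Cargo.toml
└── src/
    └── main.rs

Cargo.toml holds the project metadata and dependencies; src/main.rs is where all of the code in this post goes.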

Implementing the Decision Tree Algorithm

We'll start by implementing the decision tree algorithm. The core components of our decision tree will include:

- A Node enum representing internal decision nodes and leaf nodes.
- A DecisionTree struct that holds the root node and exposes a predict method.
- Helper functions for computing Gini impurity and finding the best split.
- A recursive build_tree function that grows the tree up to a maximum depth.

Define the Data Structures

Open src/main.rs and define the basic structures for our decision tree:

use std::collections::HashMap;
 
// Define a structure for the decision tree node
#[derive(Debug, Clone)]
enum Node {
    Internal { feature: usize, value: f64, left: Box<Node>, right: Box<Node> },
    Leaf { label: usize },
}
 
// Define the DecisionTree structure
#[derive(Debug, Clone)]
struct DecisionTree {
    root: Node,
}
 
impl DecisionTree {
    // Method to predict the label for a given data point
    fn predict(&self, features: &[f64]) -> usize {
        Self::predict_node(&self.root, features)
    }
 
    // Recursive function to traverse the tree and make a prediction
    fn predict_node(node: &Node, features: &[f64]) -> usize {
        match node {
            Node::Internal { feature, value, left, right } => {
                if features[*feature] <= *value {
                    Self::predict_node(left, features)
                } else {
                    Self::predict_node(right, features)
                }
            }
            Node::Leaf { label } => *label,
        }
    }
}
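Before we can grow trees from data, it helps to see how prediction traverses this structure. Here's a minimal, hypothetical sketch (not part of the final program) that builds a one-split tree by hand and queries it:

fn manual_tree_demo() {
    // One split on feature 0 at 1.5: values <= 1.5 go left (label 0), otherwise right (label 1)
    let tree = DecisionTree {
        root: Node::Internal {
            feature: 0,
            value: 1.5,
            left: Box::new(Node::Leaf { label: 0 }),
            right: Box::new(Node::Leaf { label: 1 }),
        },
    };

    assert_eq!(tree.predict(&[1.0, 2.0]), 0); // 1.0 <= 1.5, so the left branch is taken
    assert_eq!(tree.predict(&[2.0, 0.5]), 1); // 2.0 > 1.5, so the right branch is taken
}

You can call this from main or adapt it into a test once the rest of the code is in place.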

Implement Helper Functions

Next, we'll implement helper functions for splitting the data and finding the best split.

// Type alias for the dataset: each sample is a (features, label) pair
type Dataset = Vec<(Vec<f64>, usize)>;
 
// Function to calculate Gini impurity
fn gini_impurity(labels: &[usize]) -> f64 {
    let total = labels.len() as f64;
    let mut counts = HashMap::new();
    
    for &label in labels {
        *counts.entry(label).or_insert(0) += 1;
    }
    
    let impurity = counts.values().map(|&count| {
        let prob = count as f64 / total;
        prob * (1.0 - prob)
    }).sum();
    
    impurity
}
 
// Function to find the best split; returns (feature index, threshold, weighted Gini impurity)
fn best_split(dataset: &Dataset) -> (usize, f64, f64) {
    let mut best_feature = 0;
    let mut best_value = 0.0;
    let mut best_gini = f64::MAX;
    
    let num_features = dataset[0].0.len();
    
    for feature in 0..num_features {
        let mut values: Vec<f64> = dataset.iter().map(|(features, _)| features[feature]).collect();
        values.sort_by(|a, b| a.partial_cmp(b).unwrap());
        values.dedup();
        
        for &value in &values {
            let (left, right): (Vec<_>, Vec<_>) = dataset.iter().partition(|(features, _)| features[feature] <= value);
            
            let left_labels: Vec<_> = left.iter().map(|(_, label)| *label).collect();
            let right_labels: Vec<_> = right.iter().map(|(_, label)| *label).collect();
            
            let left_impurity = gini_impurity(&left_labels);
            let right_impurity = gini_impurity(&right_labels);
            
            let left_weight = left.len() as f64 / dataset.len() as f64;
            let right_weight = right.len() as f64 / dataset.len() as f64;
            
            let weighted_impurity = left_weight * left_impurity + right_weight * right_impurity;
            
            if weighted_impurity < best_gini {
                best_gini = weighted_impurity;
                best_feature = feature;
                best_value = value;
            }
        }
    }
    
    (best_feature, best_value, best_gini)
}
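As a quick sanity check (a hypothetical snippet, not part of the final program): a pure label set should have impurity 0, an evenly mixed two-class set should have impurity 0.5, and best_split on a cleanly separable one-feature dataset should recover the separating threshold:

fn split_demo() {
    // A pure set has zero impurity; an even two-class mix has the maximum of 0.5
    assert!(gini_impurity(&[0, 0, 0]).abs() < 1e-9);
    assert!((gini_impurity(&[0, 0, 1, 1]) - 0.5).abs() < 1e-9);

    // Feature 0 cleanly separates the classes at 1.5, so best_split should find that threshold
    let dataset: Dataset = vec![
        (vec![1.0], 0),
        (vec![1.5], 0),
        (vec![2.0], 1),
        (vec![2.5], 1),
    ];
    let (feature, value, gini) = best_split(&dataset);
    assert_eq!(feature, 0);
    assert!((value - 1.5).abs() < 1e-9);
    assert!(gini.abs() < 1e-9);
}

If you'd like to keep this check around, you can drop it into the tests module below as a #[test].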

Build the Tree

Now, implement the function to build the decision tree:

impl DecisionTree {
    // Recursive function to build the tree
    fn build_tree(dataset: &Dataset, max_depth: usize, depth: usize) -> Node {
        if dataset.is_empty() {
            return Node::Leaf { label: 0 }; // Default label
        }
        
        let labels: Vec<_> = dataset.iter().map(|(_, label)| *label).collect();
        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        
        if unique_labels.len() == 1 {
            return Node::Leaf { label: *unique_labels.iter().next().unwrap() };
        }
        
        if depth >= max_depth {
            let majority_label = labels.iter().copied().max_by_key(|&label| labels.iter().filter(|&&l| l == label).count()).unwrap();
            return Node::Leaf { label: majority_label };
        }
        
        let (feature, value, _) = best_split(dataset);
        // Clone the matching samples into two owned Dataset vectors so we can recurse on them
        let (left, right): (Dataset, Dataset) = dataset.iter().cloned().partition(|(features, _)| features[feature] <= value);
        
        Node::Internal {
            feature,
            value,
            left: Box::new(Self::build_tree(&left, max_depth, depth + 1)),
            right: Box::new(Self::build_tree(&right, max_depth, depth + 1)),
        }
    }
}
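The post constructs a DecisionTree by calling build_tree directly. As an optional convenience (a small sketch, not used in the rest of this post), you could wrap that call in a constructor:

impl DecisionTree {
    // Convenience constructor: grows a tree from a dataset with a maximum depth
    fn fit(dataset: &Dataset, max_depth: usize) -> Self {
        DecisionTree {
            root: Self::build_tree(dataset, max_depth, 0),
        }
    }
}

With this in place, DecisionTree::fit(&dataset, 2) is equivalent to building the root yourself and wrapping it in the struct.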

Testing the Implementation

Finally, let's add some tests to ensure our decision tree implementation works correctly. Add the following test cases to src/main.rs:

#[cfg(test)]
mod tests {
    use super::*;
 
    #[test]
    fn test_decision_tree() {
        let dataset = vec![
            (vec![1.0, 2.0], 0),
            (vec![1.5, 1.5], 0),
            (vec![2.0, 2.5], 1),
            (vec![2.5, 1.0], 1),
        ];
        
        let tree = DecisionTree {
            root: DecisionTree::build_tree(&dataset, 2, 0),
        };
        
        assert_eq!(tree.predict(&[1.0, 2.0]), 0);
        assert_eq!(tree.predict(&[2.0, 2.5]), 1);
    }
}
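One practical note: because all of this lives in src/main.rs, the crate is a binary and still needs a main function to compile (Cargo generated a "Hello, world!" one when the project was created). If you replaced it, something minimal like this is enough to build the crate and try the tree outside the tests:

fn main() {
    // Build a tree on the same toy dataset used in the test and print a prediction
    let dataset: Dataset = vec![
        (vec![1.0, 2.0], 0),
        (vec![1.5, 1.5], 0),
        (vec![2.0, 2.5], 1),
        (vec![2.5, 1.0], 1),
    ];
    let tree = DecisionTree {
        root: DecisionTree::build_tree(&dataset, 2, 0),
    };
    println!("Prediction for [2.0, 2.5]: {}", tree.predict(&[2.0, 2.5]));
}

Run cargo test to execute the test; it should pass, with the tree splitting on feature 0 at value 1.5.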

Conclusion

In this blog, we've implemented a basic decision tree algorithm in Rust. We covered defining data structures, implementing the core algorithms, and testing the implementation. Decision trees are versatile and can be extended further for more complex scenarios. Explore and experiment with more advanced techniques and libraries to enhance your machine learning projects in Rust!

Feel free to modify and extend this implementation based on your needs. Happy coding!

#rust #machine-learning
