Compare commits

...

5 Commits

Author SHA1 Message Date
yoshoku 0876594a42 Add regressor class with variable-random trees 2019-12-16 21:39:42 +09:00
yoshoku 5840324a2a Add classifier class with variable-random trees 2019-12-16 21:39:42 +09:00
yoshoku cab2a8ac29 Add regressor class with variable-random tree 2019-12-16 21:39:42 +09:00
yoshoku 747c338568 Add classifier class with variable-random tree 2019-12-16 21:39:42 +09:00
yoshoku 3028e788bf 🎨 Refactor decision tree abstract class 2019-12-16 21:39:42 +09:00
10 changed files with 999 additions and 2 deletions

View File

@@ -49,6 +49,8 @@ require 'rumale/tree/decision_tree_classifier'
require 'rumale/tree/decision_tree_regressor'
require 'rumale/tree/extra_tree_classifier'
require 'rumale/tree/extra_tree_regressor'
require 'rumale/tree/variable_random_tree_classifier'
require 'rumale/tree/variable_random_tree_regressor'
require 'rumale/tree/gradient_tree_regressor'
require 'rumale/ensemble/ada_boost_classifier'
require 'rumale/ensemble/ada_boost_regressor'
@@ -58,6 +60,8 @@ require 'rumale/ensemble/random_forest_classifier'
require 'rumale/ensemble/random_forest_regressor'
require 'rumale/ensemble/extra_trees_classifier'
require 'rumale/ensemble/extra_trees_regressor'
require 'rumale/ensemble/variable_random_trees_classifier'
require 'rumale/ensemble/variable_random_trees_regressor'
require 'rumale/clustering/k_means'
require 'rumale/clustering/k_medoids'
require 'rumale/clustering/gaussian_mixture'

View File

@@ -0,0 +1,138 @@
# frozen_string_literal: true
require 'rumale/tree/variable_random_tree_classifier'
require 'rumale/ensemble/random_forest_classifier'
module Rumale
module Ensemble
# VariableRandomTreesClassifier is a class that implements variable-random trees for classification.
#
# @example
# estimator =
# Rumale::Ensemble::VariableRandomTreesClassifier.new(
# n_estimators: 10, criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
# estimator.fit(training_samples, training_labels)
# results = estimator.predict(testing_samples)
#
# *Reference*
# - F. T. Liu, K. M. Ting, Y. Yu, and Z-H. Zhou, "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
class VariableRandomTreesClassifier < RandomForestClassifier
# Return the set of estimators.
# @return [Array<VariableRandomTreeClassifier>]
attr_reader :estimators
# Return the class labels.
# @return [Numo::Int32] (size: n_classes)
attr_reader :classes
# Return the importance for each feature.
# @return [Numo::DFloat] (size: n_features)
attr_reader :feature_importances
# Return the random generator for random selection of feature index.
# @return [Random]
attr_reader :rng
# Create a new classifier with variable-random trees.
#
# @param n_estimators [Integer] The number of trees for constructing variable-random trees.
# @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
# @param max_depth [Integer] The maximum depth of the tree.
# If nil is given, variable-random tree grows without concern for depth.
# @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
# If nil is given, number of leaves is not limited.
# @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
# @param max_features [Integer] The number of features to consider when searching optimal split point.
# If nil is given, split process considers 'Math.sqrt(n_features)' features.
# @param n_jobs [Integer] The number of jobs for running the fit method in parallel.
# If nil is given, the method does not execute in parallel.
# If zero or less is given, it becomes equal to the number of processors.
# This parameter is ignored if the Parallel gem is not loaded.
# @param random_seed [Integer] The seed value used to initialize the random generator.
# It is used to randomly determine the order of features when deciding splitting point.
def initialize(n_estimators: 10,
criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
max_features: nil, n_jobs: nil, random_seed: nil)
check_params_numeric_or_nil(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
max_features: max_features, n_jobs: n_jobs, random_seed: random_seed)
check_params_numeric(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
check_params_string(criterion: criterion)
check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
max_features: max_features)
super
end
# Fit the model with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [VariableRandomTreesClassifier] The learned classifier itself.
def fit(x, y)
x = check_convert_sample_array(x)
y = check_convert_label_array(y)
check_sample_label_size(x, y)
# Initialize some variables.
n_features = x.shape[1]
@params[:max_features] = Math.sqrt(n_features).to_i if @params[:max_features].nil?
@params[:max_features] = [[1, @params[:max_features]].max, n_features].min
@classes = Numo::Int32.asarray(y.to_a.uniq.sort)
sub_rng = @rng.dup
# Construct trees.
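# Each tree gets its own alpha in [0.0, 0.5): the first tree is planted fully
# random (alpha = 0) and later trees use increasingly deterministic
# test-selection, covering part of the variable-random spectrum.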
alpha_step = 0.5 / @params[:n_estimators]
alpha_vals = Array.new(@params[:n_estimators]) { |n| alpha_step * n }
rng_seeds = Array.new(@params[:n_estimators]) { sub_rng.rand(Rumale::Values.int_max) }
@estimators = if enable_parallel?
parallel_map(@params[:n_estimators]) { |n| plant_tree(alpha_vals[n], rng_seeds[n]).fit(x, y) }
else
Array.new(@params[:n_estimators]) { |n| plant_tree(alpha_vals[n], rng_seeds[n]).fit(x, y) }
end
@feature_importances =
if enable_parallel?
parallel_map(@params[:n_estimators]) { |n| @estimators[n].feature_importances }.reduce(&:+)
else
@estimators.map(&:feature_importances).reduce(&:+)
end
@feature_importances /= @feature_importances.sum
self
end
# Predict class labels for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
x = check_convert_sample_array(x)
super
end
# Predict probability for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
def predict_proba(x)
x = check_convert_sample_array(x)
super
end
# Return the index of the leaf that each sample reached.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
def apply(x)
x = check_convert_sample_array(x)
super
end
private
def plant_tree(alpha, rnd_seed)
Tree::VariableRandomTreeClassifier.new(
criterion: @params[:criterion], max_depth: @params[:max_depth],
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
max_features: @params[:max_features], alpha: alpha, random_seed: rnd_seed
)
end
end
end
end
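
Since alpha_step is 0.5 / n_estimators, a ten-tree forest plants its trees with
alpha = 0.0, 0.05, ..., 0.45. A quick usage sketch on synthetic data (not part
of the changeset; assumes the rumale gem is installed):

require 'rumale'

x = Numo::DFloat.new(100, 2).rand
y = Numo::Int32.cast(x[true, 0].gt(0.5))
estimator = Rumale::Ensemble::VariableRandomTreesClassifier.new(n_estimators: 10, random_seed: 1)
estimator.fit(x, y)
# Each tree records the alpha it was planted with.
p estimator.estimators.map { |tree| tree.params[:alpha] }
# => [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45]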

View File

@@ -0,0 +1,124 @@
# frozen_string_literal: true
require 'rumale/tree/variable_random_tree_regressor'
require 'rumale/ensemble/random_forest_regressor'
module Rumale
module Ensemble
# VariableRandomTreesRegressor is a class that implements variable-random trees for regression.
#
# @example
# estimator =
# Rumale::Ensemble::VariableRandomTreesRegressor.new(
# n_estimators: 10, criterion: 'mse', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
# estimator.fit(training_samples, training_values)
# results = estimator.predict(testing_samples)
#
# *Reference*
# - F. T. Liu, K. M. Ting, Y. Yu, and Z-H. Zhou, "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
class VariableRandomTreesRegressor < RandomForestRegressor
# Return the set of estimators.
# @return [Array<VariableRandomTreeRegressor>]
attr_reader :estimators
# Return the importance for each feature.
# @return [Numo::DFloat] (size: n_features)
attr_reader :feature_importances
# Return the random generator for random selection of feature index.
# @return [Random]
attr_reader :rng
# Create a new regressor with variable-random trees.
#
# @param n_estimators [Integer] The number of trees for constructing variable-random trees.
# @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
# @param max_depth [Integer] The maximum depth of the tree.
# If nil is given, variable-random tree grows without concern for depth.
# @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
# If nil is given, number of leaves is not limited.
# @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
# @param max_features [Integer] The number of features to consider when searching optimal split point.
# If nil is given, split process considers 'Math.sqrt(n_features)' features.
# @param n_jobs [Integer] The number of jobs for running the fit and predict methods in parallel.
# If nil is given, the methods do not execute in parallel.
# If zero or less is given, it becomes equal to the number of processors.
# This parameter is ignored if the Parallel gem is not loaded.
# @param random_seed [Integer] The seed value used to initialize the random generator.
# It is used to randomly determine the order of features when deciding splitting point.
def initialize(n_estimators: 10,
criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
max_features: nil, n_jobs: nil, random_seed: nil)
check_params_numeric_or_nil(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
max_features: max_features, n_jobs: n_jobs, random_seed: random_seed)
check_params_numeric(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
check_params_string(criterion: criterion)
check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
max_features: max_features)
super
end
# Fit the model with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
# @return [VariableRandomTreesRegressor] The learned regressor itself.
def fit(x, y)
x = check_convert_sample_array(x)
y = check_convert_tvalue_array(y)
check_sample_tvalue_size(x, y)
# Initialize some variables.
n_features = x.shape[1]
@params[:max_features] = Math.sqrt(n_features).to_i if @params[:max_features].nil?
@params[:max_features] = [[1, @params[:max_features]].max, n_features].min
sub_rng = @rng.dup
# Construct forest.
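# As in the classifier, each tree gets an alpha in [0.0, 0.5), moving from
# fully random test-selection toward more deterministic trees.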
alpha_step = 0.5 / @params[:n_estimators]
alpha_vals = Array.new(@params[:n_estimators]) { |n| alpha_step * n }
rng_seeds = Array.new(@params[:n_estimators]) { sub_rng.rand(Rumale::Values.int_max) }
@estimators = if enable_parallel?
parallel_map(@params[:n_estimators]) { |n| plant_tree(alpha_vals[n], rng_seeds[n]).fit(x, y) }
else
Array.new(@params[:n_estimators]) { |n| plant_tree(alpha_vals[n], rng_seeds[n]).fit(x, y) }
end
@feature_importances =
if enable_parallel?
parallel_map(@params[:n_estimators]) { |n| @estimators[n].feature_importances }.reduce(&:+)
else
@estimators.map(&:feature_importances).reduce(&:+)
end
@feature_importances /= @feature_importances.sum
self
end
# Predict values for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
# @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted value per sample.
def predict(x)
x = check_convert_sample_array(x)
super
end
# Return the index of the leaf that each sample reached.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign each leaf.
# @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
def apply(x)
x = check_convert_sample_array(x)
super
end
private
def plant_tree(alpha, rnd_seed)
Tree::VariableRandomTreeRegressor.new(
criterion: @params[:criterion], max_depth: @params[:max_depth],
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
max_features: @params[:max_features], alpha: alpha, random_seed: rnd_seed
)
end
end
end
end

View File

@@ -88,8 +88,7 @@ module Rumale
return put_leaf(node, y) if stop_growing?(y)
# calculate optimal parameters.
-feature_id, left_imp, right_imp, threshold, gain =
-  rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
+feature_id, left_imp, right_imp, threshold, gain = find_best_split(x, y, impurity)
return put_leaf(node, y) if gain.nil? || gain.zero?
@@ -118,6 +117,10 @@ module Rumale
@feature_ids.sample(@params[:max_features], random: @sub_rng)
end
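
# Template method for split selection: score best_split on each sampled
# feature and keep the candidate with the largest gain. Subclasses override
# this to change the selection strategy.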
def find_best_split(x, y, impurity)
rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
end
def best_split(_features, _y, _impurity)
raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
end
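
The refactor turns split selection into a template method: the node-growing
routine now calls find_best_split, so a tree variant only needs to override
that one hook. A minimal sketch of such an override (hypothetical class, not
part of this changeset):

require 'rumale'

module Rumale
  module Tree
    # Toy variant: always split on the first sampled feature instead of
    # scanning every candidate for the best gain.
    class FirstFeatureTreeClassifier < DecisionTreeClassifier
      private

      def find_best_split(x, y, impurity)
        n = rand_ids.first
        [n, *best_split(x[true, n], y, impurity)]
      end
    end
  end
end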

View File

@@ -0,0 +1,123 @@
# frozen_string_literal: true
require 'rumale/tree/decision_tree_classifier'
module Rumale
module Tree
# VariableRandomTreeClassifier is a class that implements variable-random tree for classification.
#
# @example
# estimator =
# Rumale::Tree::VariableRandomTreeClassifier.new(
# criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5,
# alpha: 0.3, random_seed: 1
# )
# estimator.fit(training_samples, training_labels)
# results = estimator.predict(testing_samples)
#
# *Reference*
# - F. T. Liu, K. M. Ting, Y. Yu, and Z-H. Zhou, "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
class VariableRandomTreeClassifier < DecisionTreeClassifier
# Return the class labels.
# @return [Numo::Int32] (size: n_classes)
attr_reader :classes
# Return the importance for each feature.
# @return [Numo::DFloat] (size: n_features)
attr_reader :feature_importances
# Return the learned tree.
# @return [Node]
attr_reader :tree
# Return the random generator for random selection of feature index.
# @return [Random]
attr_reader :rng
# Return the labels assigned each leaf.
# @return [Numo::Int32] (size: n_leafs)
attr_reader :leaf_labels
# Create a new classifier with variable-random tree.
#
# @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
# @param max_depth [Integer] The maximum depth of the tree.
# If nil is given, variable-random tree grows without concern for depth.
# @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
# If nil is given, number of leaves is not limited.
# @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
# @param max_features [Integer] The number of features to consider when searching optimal split point.
# If nil is given, split process considers all features.
# @param alpha [Float] The probability of choosing deterministic test-selection.
# @param random_seed [Integer] The seed value used to initialize the random generator.
# It is used to randomly determine the order of features when deciding splitting point.
def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
alpha: 0.3, random_seed: nil)
check_params_numeric_or_nil(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
max_features: max_features, random_seed: random_seed)
check_params_numeric(min_samples_leaf: min_samples_leaf, alpha: alpha)
check_params_string(criterion: criterion)
check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
min_samples_leaf: min_samples_leaf, max_features: max_features)
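# Collect this method's keyword arguments into a Hash, drop alpha (the
# parent DecisionTreeClassifier initializer does not accept it), and pass
# the rest to super; alpha is then stored in @params separately.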
keywd_args = method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h
keywd_args.delete(:alpha)
super(keywd_args)
@params[:alpha] = alpha
end
# Fit the model with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [VariableRandomTreeClassifier] The learned classifier itself.
def fit(x, y)
x = check_convert_sample_array(x)
y = check_convert_label_array(y)
check_sample_label_size(x, y)
super
end
# Predict class labels for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
x = check_convert_sample_array(x)
super
end
# Predict probability for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
def predict_proba(x)
x = check_convert_sample_array(x)
super
end
private
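
# With probability alpha, perform the parent's deterministic test-selection
# (scan the sampled features for the best gain); otherwise split a single
# random feature at a random threshold via best_split_rand.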
def find_best_split(x, y, impurity)
if @sub_rng.rand <= @params[:alpha]
rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
else
n = rand_ids.first
[n, *best_split_rand(x[true, n], y, impurity)]
end
end
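
# Random test-selection: the threshold is the midpoint between the feature
# values of two randomly drawn samples, and the returned gain is the impurity
# decrease of the resulting partition.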
def best_split_rand(features, y, whole_impurity)
low, hi = Array.new(features.shape[0]) { |n| n }.sample(2, random: @sub_rng)
threshold = 0.5 * (features[low] + features[hi])
l_ids = features.le(threshold).where
r_ids = features.gt(threshold).where
l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
gain = whole_impurity -
l_impurity * l_ids.size.fdiv(y.shape[0]) -
r_impurity * r_ids.size.fdiv(y.shape[0])
[l_impurity, r_impurity, threshold, gain]
end
end
end
end

View File

@@ -0,0 +1,110 @@
# frozen_string_literal: true
require 'rumale/tree/decision_tree_regressor'
module Rumale
module Tree
# VariableRandomTreeRegressor is a class that implements variable-random tree for regression.
#
# @example
# estimator =
# Rumale::Tree::VariableRandomTreeRegressor.new(
# max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5,
# alpha: 0.3, random_seed: 1
# )
# estimator.fit(training_samples, training_values)
# results = estimator.predict(testing_samples)
#
# *Reference*
# - F. T. Liu, K. M. Ting, Y. Yu, and Z-H. Zhou, "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
class VariableRandomTreeRegressor < DecisionTreeRegressor
# Return the importance for each feature.
# @return [Numo::DFloat] (size: n_features)
attr_reader :feature_importances
# Return the learned tree.
# @return [Node]
attr_reader :tree
# Return the random generator for random selection of feature index.
# @return [Random]
attr_reader :rng
# Return the values assigned each leaf.
# @return [Numo::DFloat] (shape: [n_leafs, n_outputs])
attr_reader :leaf_values
# Create a new regressor with variable-random tree.
#
# @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
# @param max_depth [Integer] The maximum depth of the tree.
# If nil is given, variable-random tree grows without concern for depth.
# @param max_leaf_nodes [Integer] The maximum number of leaves on variable-random tree.
# If nil is given, number of leaves is not limited.
# @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
# @param max_features [Integer] The number of features to consider when searching optimal split point.
# If nil is given, split process considers all features.
# @param alpha [Float] The probability of choosing deterministic test-selection.
# @param random_seed [Integer] The seed value used to initialize the random generator.
# It is used to randomly determine the order of features when deciding splitting point.
def initialize(criterion: 'mse', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
alpha: 0.3, random_seed: nil)
check_params_numeric_or_nil(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
max_features: max_features, random_seed: random_seed)
check_params_numeric(min_samples_leaf: min_samples_leaf, alpha: alpha)
check_params_string(criterion: criterion)
check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
min_samples_leaf: min_samples_leaf, max_features: max_features)
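# Collect this method's keyword arguments into a Hash, drop alpha (the
# parent DecisionTreeRegressor initializer does not accept it), and pass
# the rest to super; alpha is then stored in @params separately.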
keywd_args = method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h
keywd_args.delete(:alpha)
super(keywd_args)
@params[:alpha] = alpha
end
# Fit the model with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
# @return [VariableRandomTreeRegressor] The learned regressor itself.
def fit(x, y)
x = check_convert_sample_array(x)
y = check_convert_tvalue_array(y)
check_sample_tvalue_size(x, y)
super
end
# Predict values for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
# @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
def predict(x)
x = check_convert_sample_array(x)
super
end
private
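
# With probability alpha, perform the parent's deterministic test-selection;
# otherwise split a single random feature at a random threshold via
# best_split_rand.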
def find_best_split(x, y, impurity)
if @sub_rng.rand <= @params[:alpha]
rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
else
n = rand_ids.first
[n, *best_split_rand(x[true, n], y, impurity)]
end
end
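
# Same random test-selection as the classifier: midpoint threshold between two
# randomly drawn sample values, scored by the impurity decrease (gain).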
def best_split_rand(features, y, whole_impurity)
low, hi = Array.new(features.shape[0]) { |n| n }.sample(2, random: @sub_rng)
threshold = 0.5 * (features[low] + features[hi])
l_ids = features.le(threshold).where
r_ids = features.gt(threshold).where
l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
gain = whole_impurity -
l_impurity * l_ids.size.fdiv(y.shape[0]) -
r_impurity * r_ids.size.fdiv(y.shape[0])
[l_impurity, r_impurity, threshold, gain]
end
end
end
end

View File

@@ -0,0 +1,101 @@
# frozen_string_literal: true
require 'spec_helper'
RSpec.describe Rumale::Ensemble::VariableRandomTreesClassifier do
let(:x) { dataset[0] }
let(:y) { dataset[1] }
let(:classes) { y.to_a.uniq.sort }
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_classes) { classes.size }
let(:n_estimators) { 10 }
let(:n_jobs) { nil }
let(:estimator) do
described_class.new(n_estimators: n_estimators, max_depth: 2, max_features: 2, n_jobs: n_jobs, random_seed: 1).fit(x, y)
end
let(:score) { estimator.score(x, y) }
context 'when binary classification problem' do
let(:dataset) { two_clusters_dataset }
it 'classifies two clusters data.', :aggregate_failures do
expect(estimator.params[:n_estimators]).to eq(n_estimators)
expect(estimator.params[:max_depth]).to eq(2)
expect(estimator.params[:max_features]).to eq(2)
expect(estimator.estimators.class).to eq(Array)
expect(estimator.estimators.size).to eq(n_estimators)
expect(estimator.estimators[0].class).to eq(Rumale::Tree::VariableRandomTreeClassifier)
expect(estimator.classes.class).to eq(Numo::Int32)
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(score).to eq(1.0)
end
end
context 'when multiclass classification problem' do
let(:dataset) { three_clusters_dataset }
let(:probs) { estimator.predict_proba(x) }
let(:predicted_by_probs) { Numo::Int32[*(Array.new(n_samples) { |n| classes[probs[n, true].max_index] })] }
let(:index_mat) { estimator.apply(x) }
let(:copied) { Marshal.load(Marshal.dump(estimator)) }
it 'classifies three clusters data.', :aggregate_failures do
expect(estimator.estimators.class).to eq(Array)
expect(estimator.estimators.size).to eq(n_estimators)
expect(estimator.estimators[0].class).to eq(Rumale::Tree::VariableRandomTreeClassifier)
expect(estimator.classes.class).to eq(Numo::Int32)
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(score).to eq(1.0)
end
it 'estimates class probabilities with three clusters dataset.', :aggregate_failures do
expect(probs.class).to eq(Numo::DFloat)
expect(probs.ndim).to eq(2)
expect(probs.shape[0]).to eq(n_samples)
expect(probs.shape[1]).to eq(n_classes)
expect(predicted_by_probs).to eq(y)
end
it 'returns leaf index that each sample reached', :aggregate_failures do
expect(index_mat.ndim).to eq(2)
expect(index_mat.shape[0]).to eq(n_samples)
expect(index_mat.shape[1]).to eq(n_estimators)
expect(index_mat[true, 0]).to eq(estimator.estimators[0].apply(x))
end
it 'dumps and restores itself using Marshal module.', :aggregate_failures do
expect(estimator.class).to eq(copied.class)
expect(estimator.params).to eq(copied.params)
expect(estimator.estimators.size).to eq(copied.estimators.size)
expect(estimator.classes).to eq(copied.classes)
expect(estimator.feature_importances).to eq(copied.feature_importances)
expect(estimator.rng).to eq(copied.rng)
expect(score).to eq(copied.score(x, y))
end
context 'when n_jobs parameter is not nil' do
let(:n_jobs) { -1 }
it 'classifies three clusters data in parallel.', :aggregate_failures do
expect(estimator.estimators.class).to eq(Array)
expect(estimator.estimators.size).to eq(n_estimators)
expect(estimator.estimators[0].class).to eq(Rumale::Tree::VariableRandomTreeClassifier)
expect(estimator.classes.class).to eq(Numo::Int32)
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(score).to eq(1.0)
end
end
end
end

View File

@@ -0,0 +1,91 @@
# frozen_string_literal: true
require 'spec_helper'
RSpec.describe Rumale::Ensemble::VariableRandomTreesRegressor do
let(:x) { two_clusters_dataset[0] }
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_estimators) { 10 }
let(:n_jobs) { nil }
let(:estimator) do
described_class.new(n_estimators: n_estimators, criterion: 'mae', max_features: 2, n_jobs: n_jobs, random_seed: 9).fit(x, y)
end
let(:predicted) { estimator.predict(x) }
let(:score) { estimator.score(x, y) }
context 'when single target problem' do
let(:y) { x[true, 0] + x[true, 1]**2 }
let(:index_mat) { estimator.apply(x) }
let(:copied) { Marshal.load(Marshal.dump(estimator)) }
it 'learns the model for single regression problem.', :aggregate_failures do
expect(estimator.params[:n_estimators]).to eq(n_estimators)
expect(estimator.params[:criterion]).to eq('mae')
expect(estimator.params[:max_features]).to eq(2)
expect(estimator.estimators.class).to eq(Array)
expect(estimator.estimators.size).to eq(n_estimators)
expect(estimator.estimators[0].class).to eq(Rumale::Tree::VariableRandomTreeRegressor)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(predicted.class).to eq(Numo::DFloat)
expect(predicted.ndim).to eq(1)
expect(predicted.shape[0]).to eq(n_samples)
expect(score).to be_within(0.01).of(1.0)
end
it 'returns leaf index that each sample reached.', :aggregate_failures do
expect(index_mat.ndim).to eq(2)
expect(index_mat.shape[0]).to eq(n_samples)
expect(index_mat.shape[1]).to eq(n_estimators)
expect(index_mat[true, 0]).to eq(estimator.estimators[0].apply(x))
end
it 'dumps and restores itself using Marshal module.', :aggregate_failures do
expect(estimator.class).to eq(copied.class)
expect(estimator.params).to match(copied.params)
expect(estimator.estimators.size).to eq(copied.estimators.size)
expect(estimator.feature_importances).to eq(copied.feature_importances)
expect(estimator.rng).to eq(copied.rng)
expect(score).to eq(copied.score(x, y))
end
end
context 'when multi-target problem' do
let(:y) { Numo::DFloat[x[true, 0].to_a, (x[true, 1]**2).to_a].transpose.dot(Numo::DFloat[[0.6, 0.4], [0.0, 0.1]]) }
let(:n_outputs) { y.shape[1] }
it 'learns the model for multiple regression problem.', :aggregate_failures do
expect(estimator.estimators.class).to eq(Array)
expect(estimator.estimators.size).to eq(n_estimators)
expect(estimator.estimators[0].class).to eq(Rumale::Tree::VariableRandomTreeRegressor)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(predicted.class).to eq(Numo::DFloat)
expect(predicted.ndim).to eq(2)
expect(predicted.shape[0]).to eq(n_samples)
expect(predicted.shape[1]).to eq(n_outputs)
expect(score).to be_within(0.01).of(1.0)
end
context 'when n_jobs parameter is not nil' do
let(:n_jobs) { -1 }
it 'learns the model for multiple regression problem in parallel.', :aggregate_failures do
expect(estimator.estimators.class).to eq(Array)
expect(estimator.estimators.size).to eq(n_estimators)
expect(estimator.estimators[0].class).to eq(Rumale::Tree::VariableRandomTreeRegressor)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(predicted.class).to eq(Numo::DFloat)
expect(predicted.ndim).to eq(2)
expect(predicted.shape[0]).to eq(n_samples)
expect(predicted.shape[1]).to eq(n_outputs)
expect(score).to be_within(0.01).of(1.0)
end
end
end
end

View File

@@ -0,0 +1,152 @@
# frozen_string_literal: true
require 'spec_helper'
RSpec.describe Rumale::Tree::VariableRandomTreeClassifier do
let(:x) { dataset[0] }
let(:y) { dataset[1] }
let(:classes) { y.to_a.uniq.sort }
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_classes) { classes.size }
let(:criterion) { 'gini' }
let(:max_depth) { nil }
let(:max_leaf_nodes) { nil }
let(:min_samples_leaf) { 1 }
let(:max_features) { nil }
let(:alpha) { 0.3 }
let(:estimator) do
described_class.new(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
min_samples_leaf: min_samples_leaf, max_features: max_features,
alpha: alpha, random_seed: 1).fit(x, y)
end
let(:probs) { estimator.predict_proba(x) }
let(:predicted_by_probs) { Numo::Int32[*(Array.new(n_samples) { |n| classes[probs[n, true].max_index] })] }
let(:score) { estimator.score(x, y) }
let(:copied) { Marshal.load(Marshal.dump(estimator)) }
context 'when binary classification problem' do
let(:dataset) { two_clusters_dataset }
it 'classifies two clusters data.', :aggregate_failures do
expect(estimator.tree.class).to eq(Rumale::Tree::Node)
expect(estimator.classes.class).to eq(Numo::Int32)
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(score).to eq(1.0)
end
end
context 'when multiclass classification problem' do
let(:dataset) { three_clusters_dataset }
it 'classifies three clusters data.', :aggregate_failures do
expect(estimator.tree.class).to eq(Rumale::Tree::Node)
expect(estimator.classes.class).to eq(Numo::Int32)
expect(estimator.classes.ndim).to eq(1)
expect(estimator.classes.shape[0]).to eq(n_classes)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(score).to eq(1.0)
end
it 'estimates class probabilities with three clusters dataset.', :aggregate_failures do
expect(probs.class).to eq(Numo::DFloat)
expect(probs.ndim).to eq(2)
expect(probs.shape[0]).to eq(n_samples)
expect(probs.shape[1]).to eq(n_classes)
expect(predicted_by_probs).to eq(y)
end
it 'dumps and restores itself using Marshal module.', :aggregate_failures do
expect(estimator.class).to eq(copied.class)
expect(estimator.classes).to eq(copied.classes)
expect(estimator.feature_importances).to eq(copied.feature_importances)
expect(estimator.rng).to eq(copied.rng)
# FIXME: A slight error on the value of the threshold parameter occurs.
# It seems to be caused by rounding error of Float.
# expect(estimator.tree).to eq(copied.tree)
expect(score).to eq(copied.score(x, y))
end
context 'when max_depth parameter is given' do
let(:max_depth) { 1 }
it 'learns model with given parameters.', :aggregate_failures do
expect(estimator.params[:max_depth]).to eq(max_depth)
expect(estimator.tree.left.left).to be_nil
expect(estimator.tree.left.right).to be_nil
expect(estimator.tree.right.left).to be_nil
expect(estimator.tree.right.right).to be_nil
end
end
context 'when max_leaf_nodes parameter is given' do
let(:max_leaf_nodes) { 2 }
it 'learns model with given parameters.', :aggregate_failures do
expect(estimator.params[:max_leaf_nodes]).to eq(max_leaf_nodes)
expect(estimator.leaf_labels.size).to eq(max_leaf_nodes)
end
end
context 'when min_samples_leaf parameter is given' do
let(:min_samples_leaf) { 150 }
it 'learns model with given parameters.', :aggregate_failures do
expect(estimator.params[:min_samples_leaf]).to eq(min_samples_leaf)
expect(estimator.tree.left.leaf).to be_truthy
expect(estimator.tree.left.n_samples).to be >= min_samples_leaf
expect(estimator.tree.right).to be_nil
end
end
context 'when alpha parameter is given' do
context 'with zero' do
let(:alpha) { 0 }
it 'behaves like a random tree' do
expect(estimator.leaf_labels.size).to be > n_classes
end
end
context 'with one' do
let(:alpha) { 1 }
it 'behaves like a decision tree' do
expect(estimator.leaf_labels.size).to eq(n_classes)
end
end
end
context 'when max_features parameter is given' do
context 'with negative value' do
let(:max_features) { -10 }
it 'raises ArgumentError by validation' do
expect { estimator }.to raise_error(ArgumentError)
end
end
context 'with value larger than number of features' do
let(:max_features) { 10 }
it 'value of max_features is equal to the number of features' do
expect(estimator.params[:max_features]).to eq(x.shape[1])
end
end
context 'with valid value' do
let(:max_features) { 2 }
it 'learns model with given parameters.' do
expect(estimator.params[:max_features]).to eq(2)
end
end
end
end
end

View File

@@ -0,0 +1,151 @@
# frozen_string_literal: true
require 'spec_helper'
RSpec.describe Rumale::Tree::VariableRandomTreeRegressor do
let(:x) { two_clusters_dataset[0] }
let(:single_target) { x[true, 0] + x[true, 1]**2 }
let(:multi_target) { Numo::DFloat[x[true, 0].to_a, (x[true, 1]**2).to_a].transpose.dot(Numo::DFloat[[0.6, 0.4], [0.8, 0.2]]) }
let(:n_samples) { x.shape[0] }
let(:n_features) { x.shape[1] }
let(:n_outputs) { multi_target.shape[1] }
let(:criterion) { 'mse' }
let(:max_depth) { nil }
let(:max_leaf_nodes) { nil }
let(:min_samples_leaf) { 1 }
let(:max_features) { nil }
let(:alpha) { 0.3 }
let(:estimator) do
described_class.new(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
min_samples_leaf: min_samples_leaf, max_features: max_features,
alpha: alpha, random_seed: 8).fit(x, y)
end
let(:predicted) { estimator.predict(x) }
let(:score) { estimator.score(x, y) }
let(:copied) { Marshal.load(Marshal.dump(estimator)) }
context 'when single target problem' do
let(:y) { single_target }
it 'learns the model for single regression problem.', :aggregate_failures do
expect(estimator.tree.class).to eq(Rumale::Tree::Node)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(estimator.leaf_values.class).to eq(Numo::DFloat)
expect(estimator.leaf_values.ndim).to eq(1)
expect(estimator.leaf_values.shape[0]).not_to be_zero
expect(predicted.class).to eq(Numo::DFloat)
expect(predicted.ndim).to eq(1)
expect(predicted.shape[0]).to eq(n_samples)
expect(score).to be_within(0.01).of(1.0)
end
it 'dumps and restores itself using Marshal module.', :aggregate_failures do
expect(estimator.class).to eq(copied.class)
expect(estimator.tree.class).to eq(copied.tree.class)
expect(estimator.feature_importances).to eq(copied.feature_importances)
expect(estimator.rng).to eq(copied.rng)
expect(score).to eq(copied.score(x, y))
end
context 'when max_depth parameter is given' do
let(:max_depth) { 1 }
it 'learns model with given parameters.', :aggregate_failures do
expect(estimator.params[:max_depth]).to eq(max_depth)
expect(estimator.tree.left.left).to be_nil
expect(estimator.tree.left.right).to be_nil
expect(estimator.tree.right.left).to be_nil
expect(estimator.tree.right.right).to be_nil
end
end
context 'when max_leaf_nodes parameter is given' do
let(:max_leaf_nodes) { 2 }
it 'learns model with given parameters.', :aggregate_failures do
expect(estimator.params[:max_leaf_nodes]).to eq(max_leaf_nodes)
expect(estimator.leaf_values.size).to eq(max_leaf_nodes)
end
end
context 'when min_samples_leaf parameter is given' do
let(:min_samples_leaf) { 150 }
it 'learns model with given parameters.', :aggregate_failures do
expect(estimator.params[:min_samples_leaf]).to eq(min_samples_leaf)
expect(estimator.tree.left.leaf).to be_truthy
expect(estimator.tree.left.n_samples).to be >= min_samples_leaf
expect(estimator.tree.right).to be_nil
end
end
context 'when alpha parameter is given' do
let(:max_depth) { 2 }
context 'with zero' do
let(:alpha) { 0 }
it 'behaves like a random tree' do
expect(score).to be < 0.5
end
end
context 'with one' do
let(:alpha) { 1 }
it 'behaves like a decision tree' do
expect(score).to be > 0.5
end
end
end
context 'when max_features parameter is given' do
context 'with negative value' do
let(:max_features) { -10 }
it 'raises ArgumentError by validation' do
expect { estimator }.to raise_error(ArgumentError)
end
end
context 'with value larger than number of features' do
let(:max_features) { 10 }
it 'value of max_features is equal to the number of features' do
expect(estimator.params[:max_features]).to eq(x.shape[1])
end
end
context 'with valid value' do
let(:max_features) { 2 }
it 'learns model with given parameters.' do
expect(estimator.params[:max_features]).to eq(2)
end
end
end
end
context 'when multi-target problem' do
let(:criterion) { 'mae' }
let(:y) { multi_target }
it 'learns the model for multiple regression problem.', :aggregate_failures do
expect(estimator.tree.class).to eq(Rumale::Tree::Node)
expect(estimator.feature_importances.class).to eq(Numo::DFloat)
expect(estimator.feature_importances.ndim).to eq(1)
expect(estimator.feature_importances.shape[0]).to eq(n_features)
expect(estimator.leaf_values.class).to eq(Numo::DFloat)
expect(estimator.leaf_values.ndim).to eq(2)
expect(estimator.leaf_values.shape[0]).not_to be_zero
expect(estimator.leaf_values.shape[1]).to eq(n_outputs)
expect(predicted.class).to eq(Numo::DFloat)
expect(predicted.ndim).to eq(2)
expect(predicted.shape[0]).to eq(n_samples)
expect(predicted.shape[1]).to eq(n_outputs)
expect(score).to be_within(0.01).of(1.0)
end
end
end