Skip to content

Commit

Permalink
more complete hyper-params
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Dec 22, 2024
1 parent 1f3bb73 commit 7e9fbed
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# ConstantChangeLog

# unreleased
- more complete hyperparameter gridsearch options for classification models

# 7.4.3
updated to fastmath 3

Expand Down
33 changes: 26 additions & 7 deletions src/scicloj/ml/smile/classification.clj
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@
:type :int32
:default 1}]
:gridsearch-options {:trees (ml-gs/linear 2 50 10 :int64)
:max-nodes (ml-gs/linear 4 1000 20 :int64)}
:max-depth (ml-gs/linear 50 500 100 :int64)
:max-nodes (ml-gs/linear 4 1000 20 :int64)
:node-size (ml-gs/linear 1 10 10 :int64)
}
:property-name-stem "smile.adaboost"
:constructor #(AdaBoost/fit ^Formula %1 ^DataFrame %2 ^Properties %3)
:predictor tuple-predict-posterior}
Expand Down Expand Up @@ -163,9 +166,9 @@
:lookup-table split-rule-lookup-table
:default :gini
:description "the splitting rule"}]
:gridsearch-options {:max-nodes (ml-gs/linear 10 1000 30)
:node-size (ml-gs/linear 1 20 20)
:max-depth (ml-gs/linear 1 50 20)
:gridsearch-options {:max-nodes (ml-gs/linear 10 1000 30 :int32)
:node-size (ml-gs/linear 1 20 20 :int32)
:max-depth (ml-gs/linear 1 50 20 :int32)
:split-rule (ml-gs/categorical [:gini :entropy :classification-error])}


Expand Down Expand Up @@ -219,9 +222,16 @@
:type :float64
:default 0.7
:description "the sampling fraction for stochastic tree boosting"}]
:property-name-stem "smile.gbt"
:constructor #(GradientTreeBoost/fit ^Formula %1 ^DataFrame %2 ^Properties %3)
:predictor tuple-predict-posterior}
:gridsearch-options
{:ntrees (ml-gs/linear 10 1000 100 :int32)
:max-depth (ml-gs/linear 10 100 100 :int32)
:max-nodes (ml-gs/linear 10 100 100 :int32)
:node-size (ml-gs/linear 1 100 100 :int32)
:shrinkage (ml-gs/linear 0.01 1 100 :float64)}

:property-name-stem "smile.gbt"
:constructor #(GradientTreeBoost/fit ^Formula %1 ^DataFrame %2 ^Properties %3)
:predictor tuple-predict-posterior}

:knn {:class KNN
:name :knn
Expand Down Expand Up @@ -360,6 +370,15 @@
{:name :class-weight :type :string :default nil
:description "Priors of the classes. The weight of each class is roughly the ratio of samples in each class. For example, if there are 400 positive samples and 100 negative samples, the classWeight should be [1, 4] (assuming label 0 is of negative, label 1 is of positive)"}]

:gridsearch-options
{:trees (ml-gs/linear 10 1000 100 :int32)
:max-depth (ml-gs/linear 10 100 100 :int32)
:max-nodes (ml-gs/linear 10 100 100 :int32)
:node-size (ml-gs/linear 1 100 100 :int32)
:sample-rate (ml-gs/linear 0.1 1.0 100)
:split-rule (ml-gs/categorical [:gini
:entropy
:classification-error])}
:property-name-stem "smile.random.forest"}})
;; fix when this is released:
;; https://github.com/haifengl/smile/blob/2352cff6880056eb9a03dbe2556acdbd8f07ddda/core/src/main/java/smile/regression/RBFNetwork.java#L165
Expand Down
24 changes: 18 additions & 6 deletions test/scicloj/ml/smile/smile_ml_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
[tech.v3.dataset :as ds]
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.dataset.utils :as ds-utils]
[tech.v3.datatype :as dtype]
[tech.v3.dataset.column-filters :as cf]
[scicloj.ml.smile.malli :as malli]
;; [tablecloth.api :as]
[clojure.test :refer [deftest is]]))
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml.gridsearch :as ml-gs]))


;;shut that shit up.
Expand All @@ -37,10 +36,23 @@
:smile.classification/svm
:smile.classification/discrete-naive-bayes
:smile.classification/sparse-logistic-regression})))




(defn- one-gs-option
  "Returns the first option map that `ml-gs/sobol-gridsearch` yields for the
  hyperparameters of `model-type`, with `:model-type` set to `model-type`."
  [model-type]
  (-> (ml/hyperparameters model-type)
      ml-gs/sobol-gridsearch
      first
      (assoc :model-type model-type)))



(deftest smile-classification-hyperparameters-test
  ;; For every model in `smile-classification-models`, build one option map
  ;; via `one-gs-option` and pass it to `verify/basic-classification`.
  (doseq [model-type smile-classification-models]
    (-> model-type
        one-gs-option
        verify/basic-classification)))


(deftest smile-classification-test
Expand Down

0 comments on commit 7e9fbed

Please sign in to comment.