case class BoltzmannSplitter(temperature: Double, rng: Random = Random) extends Splitter[Double] with Product with Serializable
Find a split for a regression problem
The splits are picked with a probability that is related to the reduction in variance: P(split) ∝ exp[ -{remaining variance} / ({temperature} * {total variance}) ]. Recall that the "variance" here is weighted by the sample size, so it is really the sum of the squared differences from the mean on each side of the split. This is analogous to simulated annealing and Metropolis-Hastings.
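A minimal Scala sketch of the selection rule follows. It assumes each candidate split has already been reduced to its remaining weighted sum of squares. The object and its helpers (`BoltzmannSketch`, `weightedSse`, `boltzmannWeights`, `sampleIndex`) are hypothetical illustrations, not this class's implementation or the library's API.

```scala
import scala.util.Random

object BoltzmannSketch {
  // Weighted sum of squared deviations from the weighted mean, i.e. the
  // sample-size-weighted "variance" used in the formula above.
  def weightedSse(labels: Seq[Double], weights: Seq[Double]): Double = {
    val totalWeight = weights.sum
    val mean = labels.zip(weights).map { case (y, w) => y * w }.sum / totalWeight
    labels.zip(weights).map { case (y, w) => w * (y - mean) * (y - mean) }.sum
  }

  // P(split_i) proportional to exp(-remaining_i / (temperature * total)),
  // normalized to sum to one. Assumes total > 0 and temperature > 0.
  def boltzmannWeights(remaining: Seq[Double], temperature: Double, total: Double): Seq[Double] = {
    val best = remaining.foldLeft(Double.PositiveInfinity)(math.min)
    val unnormalized = remaining.map(r => math.exp(-(r - best) / (temperature * total)))
    val z = unnormalized.sum
    unnormalized.map(_ / z)
  }

  // Inverse-CDF sampling of a candidate index from the probabilities.
  def sampleIndex(probs: Seq[Double], rng: Random): Int = {
    val u = rng.nextDouble()
    val cumulative = probs.scanLeft(0.0)(_ + _).tail
    val idx = cumulative.indexWhere(c => u < c)
    if (idx >= 0) idx else probs.length - 1 // guard against rounding in the last cumulative sum
  }
}
```

Subtracting the smallest remaining variance before exponentiating multiplies every weight by the same constant. The normalized probabilities are therefore unchanged, and underflow is avoided when the exponents are large.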
The motivation here is to reduce the correlation of the trees by making random choices between splits that are almost as good as the strictly optimal one. Reducing the correlation between trees reduces the variance of an ensemble method (e.g. random forests): the variance both decreases more quickly with the tree count and reaches a lower floor. In this paragraph, "variance" is used in the sense of the bias-variance trade-off.
Division by the local total variance makes the splitting behavior invariant to data size and to the scale of the labels. It also means, however, that you can't set the temperature based on a known absolute noise scale; for that, you'd want to divide by the total weight rather than the total variance.
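The invariance argument can be checked numerically. Multiplying every label by c multiplies both the remaining and the total weighted sum of squares by c². Scaling every weight by k (equivalent to replicating each row k times) multiplies both by k. Either way, the exponent is unchanged. In this sketch all names are hypothetical, and `exponentByTotalWeight` is only an assumed form of the total-weight alternative described above. Its exponent does scale with c².

```scala
object ScaleInvarianceSketch {
  // Weighted sum of squared deviations from the weighted mean.
  def weightedSse(labels: Seq[Double], weights: Seq[Double]): Double = {
    val totalWeight = weights.sum
    val mean = labels.zip(weights).map { case (y, w) => y * w }.sum / totalWeight
    labels.zip(weights).map { case (y, w) => w * (y - mean) * (y - mean) }.sum
  }

  // Remaining variance after cutting labels (already sorted by the feature) at `cut`.
  def remaining(labels: Seq[Double], weights: Seq[Double], cut: Int): Double = {
    val (leftLabels, rightLabels) = labels.splitAt(cut)
    val (leftWeights, rightWeights) = weights.splitAt(cut)
    weightedSse(leftLabels, leftWeights) + weightedSse(rightLabels, rightWeights)
  }

  // Exponent normalized by the local total variance, as described above.
  def exponentByTotalVariance(labels: Seq[Double], weights: Seq[Double], cut: Int, temperature: Double): Double =
    -remaining(labels, weights, cut) / (temperature * weightedSse(labels, weights))

  // Hypothetical alternative: normalized by the total weight instead, so the
  // temperature carries the units of the label variance.
  def exponentByTotalWeight(labels: Seq[Double], weights: Seq[Double], cut: Int, temperature: Double): Double =
    -remaining(labels, weights, cut) / (temperature * weights.sum)
}
```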
TODO: allow the rescaling to happen based on the total weight instead of the total variance, as an option
Created by maxhutch on 11/29/16.
- temperature
used to control how sensitive the probability of a split is to its change in variance. The temperature can be thought of as a hyperparameter.
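The temperature's role can be seen from the limits of the split probability. Here, as an assumed notation, R_i denotes the remaining (weighted) variance of candidate i, V the node's total variance, T the temperature, and n the number of candidates:

```latex
\begin{aligned}
P_i &= \frac{\exp\!\left(-\frac{R_i}{T V}\right)}{\sum_{j=1}^{n} \exp\!\left(-\frac{R_j}{T V}\right)}
     = \frac{1}{\sum_{j=1}^{n} \exp\!\left(-\frac{R_j - R_i}{T V}\right)} \\
T \to \infty &: \quad \frac{R_j - R_i}{T V} \to 0 \;\Rightarrow\; P_i \to \frac{1}{n} \\
T \to 0^{+} &: \quad P_i \to
  \begin{cases}
    1 & \text{if } R_i < R_j \text{ for all } j \neq i \\
    0 & \text{if } R_j < R_i \text{ for some } j
  \end{cases}
\end{aligned}
```

A low temperature therefore recovers the greedy (strictly optimal) split, while a high temperature approaches a uniform choice among the candidates. When several candidates tie for the minimum, the probability is shared equally among them as T approaches zero.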
Instance Constructors
- new BoltzmannSplitter(temperature: Double, rng: Random = Random)
Value Members
- final def !=(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def ##: Int
- Definition Classes
- AnyRef → Any
- final def ==(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def asInstanceOf[T0]: T0
- Definition Classes
- Any
- def clone(): AnyRef
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.CloneNotSupportedException]) @native()
- final def eq(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- def finalize(): Unit
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.Throwable])
- def getBestSplit(data: Seq[(Vector[AnyVal], Double, Double)], numFeatures: Int, minInstances: Int): (Split, Double)
Get a split probabilistically, considering numFeatures random features (without replacement) and ensuring that each resulting partition has at least minInstances instances in it.
- data
to split
- numFeatures
to consider, randomly
- minInstances
minimum instances permitted in a post-split partition
- returns
a split object that divides data, sampled probabilistically rather than chosen strictly optimally, paired with a Double
- Definition Classes
- BoltzmannSplitter → Splitter
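The getBestSplit contract above can be illustrated with a self-contained sketch. It makes several assumptions:

- Features are numeric (`Vector[Double]` rather than `Vector[AnyVal]`).
- Rows have the form (features, label, weight).
- Thresholds sit midway between adjacent distinct feature values.
- Temperature and rng are passed explicitly instead of being read from the class fields.

`ThresholdSplit`, `BestSplitSketch` and `getSplit` are hypothetical names. The returned Double is assumed here to be the variance reduction, which may not match what the library actually returns.

```scala
import scala.util.Random

// Hypothetical stand-in for the library's Split type: feature <= threshold goes left.
final case class ThresholdSplit(feature: Int, threshold: Double)

object BestSplitSketch {
  // Weighted sum of squared deviations for (label, weight) pairs.
  private def sse(rows: Seq[(Double, Double)]): Double = {
    val totalWeight = rows.map(_._2).sum
    if (totalWeight == 0.0) 0.0
    else {
      val mean = rows.map { case (y, w) => y * w }.sum / totalWeight
      rows.map { case (y, w) => w * (y - mean) * (y - mean) }.sum
    }
  }

  private def labelsAndWeights(rows: Seq[(Vector[Double], Double, Double)]): Seq[(Double, Double)] =
    rows.map { case (_, y, w) => (y, w) }

  // Assumes non-empty data in which every row has the same number of features.
  def getSplit(
      data: Seq[(Vector[Double], Double, Double)], // (features, label, weight)
      numFeatures: Int,
      minInstances: Int,
      temperature: Double,
      rng: Random
  ): Option[(ThresholdSplit, Double)] = {
    val total = sse(labelsAndWeights(data))
    // Pick numFeatures distinct feature indices at random (without replacement).
    val features = rng.shuffle((0 until data.head._1.length).toVector).take(numFeatures)
    val minSide = math.max(minInstances, 1) // each side keeps at least minInstances rows
    val candidates = for {
      f <- features
      sorted = data.sortWith((a, b) => a._1(f) < b._1(f))
      cut <- minSide to (sorted.length - minSide)
      lo = sorted(cut - 1)._1(f)
      hi = sorted(cut)._1(f)
      if lo < hi // only cut between distinct feature values
    } yield {
      val (left, right) = sorted.splitAt(cut)
      val remaining = sse(labelsAndWeights(left)) + sse(labelsAndWeights(right))
      (ThresholdSplit(f, (lo + hi) / 2), remaining)
    }
    if (candidates.isEmpty || total == 0.0) None
    else {
      val best = candidates.map(_._2).foldLeft(Double.PositiveInfinity)(math.min)
      val weights = candidates.map { case (_, r) => math.exp(-(r - best) / (temperature * total)) }
      // Inverse-CDF sampling on the unnormalized Boltzmann weights.
      val u = rng.nextDouble() * weights.sum
      val cumulative = weights.scanLeft(0.0)(_ + _).tail
      val i = cumulative.indexWhere(c => u < c) match {
        case -1 => weights.length - 1
        case k  => k
      }
      val (split, remaining) = candidates(i)
      Some((split, total - remaining)) // assumed score: reduction in weighted variance
    }
  }
}
```

At a very low temperature this behaves like a greedy splitter. Raising the temperature lets near-optimal cuts be chosen more often, which is the decorrelation effect the class description aims for.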
- final def getClass(): Class[_ <: AnyRef]
- Definition Classes
- AnyRef → Any
- Annotations
- @native()
- final def isInstanceOf[T0]: Boolean
- Definition Classes
- Any
- final def ne(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- final def notify(): Unit
- Definition Classes
- AnyRef
- Annotations
- @native()
- final def notifyAll(): Unit
- Definition Classes
- AnyRef
- Annotations
- @native()
- def productElementNames: Iterator[String]
- Definition Classes
- Product
- val rng: Random
- final def synchronized[T0](arg0: => T0): T0
- Definition Classes
- AnyRef
- val temperature: Double
- final def wait(): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def wait(arg0: Long, arg1: Int): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def wait(arg0: Long): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException]) @native()