diff --git a/RBF_tf.py b/RBF_tf.py index fb6d95c..6557def 100644 --- a/RBF_tf.py +++ b/RBF_tf.py @@ -53,3 +53,49 @@ class RBFLayer(tf.keras.layers.Layer): rbf_output = tf.exp(-tf.square(distances) / (2 * tf.square(self.sigmas))) return rbf_output + +class RBFAdaptiveModel(tf.keras.Model): + """ RBF Adaptive Model using TF Subclassing API. + + Outputs one control signal adaptation. Determined by output layer + number of neurons. + + ... + + Attributes + ---------- + n_centers : int + The number of RBF centers. + + Methods + ------- + call(inputs): + TF call method to implement forward pass of model. + """ + def __init__(self, n_centers): + """ Constructs RBF and output layers. + + Parameters + ---------- + n_centers : int + The number of RBF centers. + """ + super().__init__() + self.rbf_layer = RBFLayer(n_centers) + self.output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """ Implement forward pass. + + Parameters + ---------- + inputs : tensor + The points in space to adapt with. + + Returns + ------- + Adapted control signal. + """ + rbf_output = self.rbf_layer(inputs) + control_signal = self.output_layer(rbf_output) + return control_signal