diff --git a/RBF_tf.py b/RBF_tf.py index 63615b9..432df9b 100644 --- a/RBF_tf.py +++ b/RBF_tf.py @@ -17,7 +17,7 @@ class RBFLayer(tf.keras.layers.Layer): call(inputs): TF call method to implement forward pass. """ - def __init__(self, n_centers, input_dim=3): + def __init__(self, n_centers, input_dim): """ Constructs RBF centers and standard deviations as trainable weights. Parameters @@ -25,7 +25,7 @@ class RBFLayer(tf.keras.layers.Layer): n_centers : int The number of RBF centers. input_dim : int - The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd + The dimensions of the RBF centers. """ super().__init__() self.n_centers = n_centers @@ -66,22 +66,26 @@ class RBFAdaptiveModel(tf.keras.Model): ---------- n_centers : int The number of RBF centers. + imput_dim : int + The dimensions of the RBF centers. Methods ------- call(inputs): TF call method to implement forward pass of model. """ - def __init__(self, n_centers): + def __init__(self, n_centers, input_dim=3): """ Constructs RBF and output layers. Parameters ---------- n_centers : int The number of RBF centers. + input_dim : int + The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd """ super().__init__() - self.rbf_layer = RBFLayer(n_centers) + self.rbf_layer = RBFLayer(n_centers, input_dim) self.output_layer = tf.keras.layers.Dense(1) def call(self, inputs):