From cc6ff7ca80d3f66aaed5e17dc4006463017c4892 Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Mon, 7 Oct 2024 19:57:40 -0700 Subject: [PATCH] Moved input_dim parameter to class constructor. --- RBF_tf.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/RBF_tf.py b/RBF_tf.py index 63615b9..432df9b 100644 --- a/RBF_tf.py +++ b/RBF_tf.py @@ -17,7 +17,7 @@ class RBFLayer(tf.keras.layers.Layer): call(inputs): TF call method to implement forward pass. """ - def __init__(self, n_centers, input_dim=3): + def __init__(self, n_centers, input_dim): """ Constructs RBF centers and standard deviations as trainable weights. Parameters @@ -25,7 +25,7 @@ class RBFLayer(tf.keras.layers.Layer): n_centers : int The number of RBF centers. input_dim : int - The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd + The dimensions of the RBF centers. """ super().__init__() self.n_centers = n_centers @@ -66,22 +66,26 @@ class RBFAdaptiveModel(tf.keras.Model): ---------- n_centers : int The number of RBF centers. + imput_dim : int + The dimensions of the RBF centers. Methods ------- call(inputs): TF call method to implement forward pass of model. """ - def __init__(self, n_centers): + def __init__(self, n_centers, input_dim=3): """ Constructs RBF and output layers. Parameters ---------- n_centers : int The number of RBF centers. + input_dim : int + The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd """ super().__init__() - self.rbf_layer = RBFLayer(n_centers) + self.rbf_layer = RBFLayer(n_centers, input_dim) self.output_layer = tf.keras.layers.Dense(1) def call(self, inputs):