Moved input_dim parameter to class constructor.

This commit is contained in:
Andru Liu
2024-10-07 19:57:40 -07:00
parent 9ac342f003
commit cc6ff7ca80
+8 -4
View File
@@ -17,7 +17,7 @@ class RBFLayer(tf.keras.layers.Layer):
call(inputs): call(inputs):
TF call method to implement forward pass. TF call method to implement forward pass.
""" """
def __init__(self, n_centers, input_dim=3): def __init__(self, n_centers, input_dim):
""" Constructs RBF centers and standard deviations as trainable weights. """ Constructs RBF centers and standard deviations as trainable weights.
Parameters Parameters
@@ -25,7 +25,7 @@ class RBFLayer(tf.keras.layers.Layer):
n_centers : int n_centers : int
The number of RBF centers. The number of RBF centers.
input_dim : int input_dim : int
The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd The dimensions of the RBF centers.
""" """
super().__init__() super().__init__()
self.n_centers = n_centers self.n_centers = n_centers
@@ -66,22 +66,26 @@ class RBFAdaptiveModel(tf.keras.Model):
---------- ----------
n_centers : int n_centers : int
The number of RBF centers. The number of RBF centers.
imput_dim : int
The dimensions of the RBF centers.
Methods Methods
------- -------
call(inputs): call(inputs):
TF call method to implement forward pass of model. TF call method to implement forward pass of model.
""" """
def __init__(self, n_centers): def __init__(self, n_centers, input_dim=3):
""" Constructs RBF and output layers. """ Constructs RBF and output layers.
Parameters Parameters
---------- ----------
n_centers : int n_centers : int
The number of RBF centers. The number of RBF centers.
input_dim : int
The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd
""" """
super().__init__() super().__init__()
self.rbf_layer = RBFLayer(n_centers) self.rbf_layer = RBFLayer(n_centers, input_dim)
self.output_layer = tf.keras.layers.Dense(1) self.output_layer = tf.keras.layers.Dense(1)
def call(self, inputs): def call(self, inputs):