mirror of
https://github.com/WallabyLester/RBF-aPID-Controller.git
synced 2026-05-10 04:58:14 +08:00
Moved input_dim parameter to class constructor.
This commit is contained in:
@@ -17,7 +17,7 @@ class RBFLayer(tf.keras.layers.Layer):
|
||||
call(inputs):
|
||||
TF call method to implement forward pass.
|
||||
"""
|
||||
def __init__(self, n_centers, input_dim=3):
|
||||
def __init__(self, n_centers, input_dim):
|
||||
""" Constructs RBF centers and standard deviations as trainable weights.
|
||||
|
||||
Parameters
|
||||
@@ -25,7 +25,7 @@ class RBFLayer(tf.keras.layers.Layer):
|
||||
n_centers : int
|
||||
The number of RBF centers.
|
||||
input_dim : int
|
||||
The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd
|
||||
The dimensions of the RBF centers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.n_centers = n_centers
|
||||
@@ -66,22 +66,26 @@ class RBFAdaptiveModel(tf.keras.Model):
|
||||
----------
|
||||
n_centers : int
|
||||
The number of RBF centers.
|
||||
imput_dim : int
|
||||
The dimensions of the RBF centers.
|
||||
|
||||
Methods
|
||||
-------
|
||||
call(inputs):
|
||||
TF call method to implement forward pass of model.
|
||||
"""
|
||||
def __init__(self, n_centers):
|
||||
def __init__(self, n_centers, input_dim=3):
|
||||
""" Constructs RBF and output layers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_centers : int
|
||||
The number of RBF centers.
|
||||
input_dim : int
|
||||
The dimensions of the RBF centers. Default of 3 for Kp, Ki, Kd
|
||||
"""
|
||||
super().__init__()
|
||||
self.rbf_layer = RBFLayer(n_centers)
|
||||
self.rbf_layer = RBFLayer(n_centers, input_dim)
|
||||
self.output_layer = tf.keras.layers.Dense(1)
|
||||
|
||||
def call(self, inputs):
|
||||
|
||||
Reference in New Issue
Block a user