diff --git a/byol/utils/networks.py b/byol/utils/networks.py index e2e4bf4..7b5b491 100644 --- a/byol/utils/networks.py +++ b/byol/utils/networks.py @@ -56,7 +56,7 @@ class ResNetTorso(hk.Module): def __init__( self, blocks_per_group: Sequence[int], - num_classes: int = None, + num_classes: Optional[int] = None, bn_config: Optional[Mapping[str, float]] = None, resnet_v2: bool = False, bottleneck: bool = True, diff --git a/mmv/models/mm_embeddings.py b/mmv/models/mm_embeddings.py index 0de0e15..d54682b 100644 --- a/mmv/models/mm_embeddings.py +++ b/mmv/models/mm_embeddings.py @@ -16,7 +16,7 @@ # Lint as: python3. """Model for text-video-audio embeddings.""" -from typing import Any, Dict +from typing import Any, Dict, Optional import haiku as hk import jax @@ -358,7 +358,7 @@ class EmbeddingModule(hk.Module): embedding_dim: int, mode: str = "linear", use_bn_out: bool = False, - bn_config: Dict[str, Any] = None, + bn_config: Optional[Dict[str, Any]] = None, use_xreplica_bn: bool = True, name="embedding_module"): self._embedding_dim = embedding_dim