mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2025-12-17 14:14:15 +08:00
128 lines
3.9 KiB
Python
128 lines
3.9 KiB
Python
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""Give names to parts of a numpy array."""
|
|
|
|
from typing import Iterable, List, Mapping, MutableMapping, Tuple, Union
|
|
|
|
import numpy as np
|
|
|
|
|
|
def lengths_to_ranges(
    lengths: Mapping[str, int]) -> MutableMapping[str, List[int]]:
  """Lay named lengths end to end and list the indices each one covers.

  Keys are placed in the mapping's iteration order, each run starting where
  the previous one ended.

  Eg: {a: 2, b: 3} -> {a: [0, 1], b: [2, 3, 4]} .

  Args:
    lengths: Number of consecutive indices to allocate for each key.

  Returns:
    A new dict mapping each key to the list of its consecutive indices.
  """
  result = {}
  offset = 0
  for name, n in lengths.items():
    begin, offset = offset, offset + n
    result[name] = list(range(begin, offset))
  return result
|
|
|
|
|
|
class NamedRanges:
  """Given a map of {key: count}, give various views into it.

  Each key is assigned a consecutive run of indices into a flat array of
  length `size`, laid out in the mapping's iteration order. Additional,
  possibly overlapping, ranges can be registered later with `set_range`.
  """

  def __init__(self, counts: Mapping[str, int]):
    self._ranges = lengths_to_ranges(counts)
    # Length of the flat array. Custom ranges added via `set_range` do not
    # change it.
    self._size = sum(counts.values())

  def __getitem__(self, name) -> List[int]:
    """Return the list of indices registered under `name` (KeyError if absent)."""
    return self._ranges[name]

  def __contains__(self, name) -> bool:
    return name in self._ranges

  def set_range(self, name: str, value: Iterable[int]):
    """Overwrite or create a custom range, which may intersect with others.

    The indices are copied into a new list, so the caller mutating `value`
    afterwards does not alter the stored range.
    """
    self._ranges[name] = list(value)

  def range(self, name: str) -> List[int]:
    """Alias for `self[name]`."""
    return self[name]

  def index(self, name: str) -> int:
    """Return the single index registered under `name`.

    Raises:
      KeyError: if `name` is not a known range.
      ValueError: if the range of `name` does not hold exactly one index.
    """
    rng = self[name]
    if len(rng) != 1:
      # Report the actual count: an empty range is also an error here, and
      # "multiple values" would be misleading for it.
      raise ValueError(f"{name} has {len(rng)} values, expected exactly 1")
    return rng[0]

  def count(self, name: str) -> int:
    """Return how many indices are registered under `name`."""
    return len(self[name])

  def names(self) -> Iterable[str]:
    """Return a live view of all range names, including custom ones."""
    return self._ranges.keys()

  def ranges(self) -> Iterable[Tuple[str, List[int]]]:
    """Return a live view of (name, indices) pairs, including custom ones."""
    return self._ranges.items()

  def counts(self) -> Mapping[str, int]:
    """Return a new dict mapping every range name to its number of indices."""
    return {k: len(v) for k, v in self._ranges.items()}

  @property
  def size(self) -> int:
    """Length of the flat array described by the constructor's counts."""
    return self._size

  def named_array(self, array: np.ndarray) -> "NamedArray":
    """Wrap an existing array of shape (size,) for access by name."""
    return NamedArray(array, self)

  def new_named_array(self) -> "NamedArray":
    """Create a zero-filled NamedArray of shape (size,)."""
    return NamedArray(np.zeros((self.size,)), self)

  def new_random_named_array(self) -> "NamedArray":
    """Create a NamedArray of shape (size,) filled from numpy's global RNG."""
    return NamedArray(np.random.uniform(size=(self.size,)), self)
|
|
|
|
|
|
class NamedArray:
  """Given a numpy array and a NamedRange, access slices by name.

  Keys are either a range name, or a `(name, i)` tuple where `i` (an int, a
  list of ints, or a slice) selects positions *within* that named range.
  """

  def __init__(self, array: np.ndarray, names: NamedRanges):
    if array.shape != (names.size,):
      raise ValueError(f"Wrong sizes: {array.shape} != ({names.size},)")
    self._array = array
    self._names = names

  def _indices(
      self, name: Union[str, Tuple[str, Union[int, List[int], slice]]]):
    """Resolve a name or (name, i) key into positions in the flat array."""
    if isinstance(name, str):
      return np.asarray(self._names[name], dtype=np.intp)
    range_name, i = name
    # Force an integer dtype: an empty range would otherwise become a float64
    # array, which numpy refuses to use as an index after sub-selection.
    return np.asarray(self._names[range_name], dtype=np.intp)[i]

  def __getitem__(
      self, name: Union[str, Tuple[str, Union[int, List[int],
                                              slice]]]) -> np.ndarray:
    """Return a read-only copy of the named values.

    A `(name, int)` key yields a numpy scalar instead of an array.
    """
    arr = self._array[self._indices(name)]
    if not np.isscalar(arr):
      # Read-only because it's indexed by an array of potentially non-contiguous
      # indices, which isn't representable as a normal tensor, which forces a
      # copy and therefore writes don't modify the underlying array as expected.
      arr.flags.writeable = False
    return arr

  def __setitem__(
      self, name: Union[str, Tuple[str, Union[int, List[int], slice]]], value):
    """Set one or more values of a range to a value (written in place)."""
    self._array[self._indices(name)] = value

  @property
  def array(self) -> np.ndarray:
    """The underlying flat array; writes to it are visible through names."""
    return self._array

  @property
  def names(self) -> NamedRanges:
    return self._names

  def to_dict(self) -> Mapping[str, np.ndarray]:
    """Return {range name: read-only copy of its values} for every range."""
    return {k: self[k] for k in self._names.names()}
|