from typing import Optional, Tuple, Union
import numpy as np
import syna
from syna import utils
from syna.core import Function, Tensor, as_tensor
# ----------------------
# Shape manipulations
# ----------------------
class Reshape(Function):
    """Reshape tensor to a given shape."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape
    def forward(self, x):
        self.x_shape = x.shape
        return x.reshape(self.shape)
    def backward(self, gy):
        return reshape(gy, self.x_shape)

def reshape(x, shape) -> Tensor:
    """Reshape tensor; if the shape already matches, return ``as_tensor(x)``."""
    if x.shape == shape:
        return as_tensor(x)
    return Reshape(shape)(x)
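
# Usage sketch, assuming a DeZero-style Tensor that wraps an ndarray and
# supports backward() (Tensor is imported above from syna.core):
#
#   x = Tensor(np.arange(6.0).reshape(2, 3))
#   y = reshape(x, (6,))
#   y.backward()
#   # x.grad has shape (2, 3): Reshape.backward restores the input shape.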

class Transpose(Function):
    """Transpose with optional axes permutation."""

    def __init__(self, axes=None) -> None:
        self.axes = axes
    def forward(self, x):
        return x.transpose(self.axes)
    def backward(self, gy):
        if self.axes is None:
            return transpose(gy)
        # Undo the forward permutation: argsort of the (normalized) axes
        # yields the inverse permutation.
        axes_len = len(self.axes)
        inv_axes = tuple(np.argsort([ax % axes_len for ax in self.axes]))
        return transpose(gy, inv_axes)

def transpose(x, axes=None) -> Tensor:
    """Transpose tensor along the given axes (reverse all axes if None)."""
    return Transpose(axes)(x)
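
# Usage sketch: the inverse permutation is recovered with np.argsort. For
# axes=(1, 2, 0), argsort gives (2, 0, 1), which maps gy back to x's layout:
#
#   x = Tensor(np.zeros((2, 3, 4)))
#   y = transpose(x, (1, 2, 0))   # y.shape == (3, 4, 2)
#   y.backward()
#   # x.grad.shape == (2, 3, 4)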

class GetItem(Function):
    """Supports x[slices] and produces gradient via GetItemGrad."""

    def __init__(self, slices):
        self.slices = slices
    def forward(self, x):
        return x[self.slices]
    def backward(self, gy):
        (x,) = self.inputs
        return GetItemGrad(self.slices, x.shape)(gy)

class GetItemGrad(Function):
    """Gradient for getitem: scatters gy back into the original shape."""

    def __init__(self, slices, in_shape):
        self.slices = slices
        self.in_shape = in_shape
    def forward(self, x):
        gx = np.zeros(self.in_shape, dtype=x.dtype)
        # np.add.at accumulates in place, so repeated indices in self.slices
        # sum their gradient contributions rather than overwriting.
        np.add.at(gx, self.slices, x)
        return gx
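
    # Hedged sketch of a double-backprop pass, assuming the DeZero-style
    # convention that Function subclasses may define backward for
    # second-order gradients; gathering with the same slices undoes the
    # scatter performed in forward.
    def backward(self, ggx):
        return get_item(ggx, self.slices)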

def get_item(x, slices) -> Tensor:
    """Index into a tensor with slices."""
    return GetItem(slices)(x)
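
# Usage sketch: repeated indices accumulate their gradients via np.add.at.
#
#   x = Tensor(np.array([1.0, 2.0, 3.0]))
#   y = get_item(x, np.array([0, 0, 2]))   # gathers x[0], x[0], x[2]
#   y.backward()
#   # x.grad == [2., 0., 1.]: index 0 receives two contributions.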

def expand_dims(x, axis: int) -> Tensor:
    """Insert a dimension of size 1 at index ``axis``."""
    x = as_tensor(x)
    shape = list(x.shape)
    # Normalize negative axes so axis=-1 appends a trailing dimension,
    # matching numpy.expand_dims semantics.
    if axis < 0:
        axis = axis + len(shape) + 1
    shape.insert(axis, 1)
    return reshape(x, tuple(shape))
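
# Usage sketch:
#
#   x = Tensor(np.zeros((2, 3)))
#   expand_dims(x, 0).shape    # (1, 2, 3)
#   expand_dims(x, -1).shape   # (2, 3, 1), via the normalization above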

def flatten(x) -> Tensor:
    """Flatten all dimensions except the first (batch) dimension."""
    return reshape(x, (x.shape[0], -1))
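
# Usage sketch: the batch dimension is kept, all others are merged.
#
#   x = Tensor(np.zeros((8, 3, 32, 32)))
#   flatten(x).shape   # (8, 3072)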
# ----------------------
# Reductions & broadcasting
# ----------------------
class Sum(Function):
    """Sum reduction, supports axis and keepdims."""

    def __init__(self, axis, keepdims) -> None:
        self.axis = axis
        self.keepdims = keepdims
    def forward(self, x):
        self.x_shape = x.shape
        return x.sum(axis=self.axis, keepdims=self.keepdims)
    def backward(self, gy):
        # Restore the reduced axes in gy, then broadcast to the input shape.
        gy = utils.reshape_sum_backward(gy, self.x_shape, self.axis, self.keepdims)
        return broadcast_to(gy, self.x_shape)

def sum(x, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims=False) -> Tensor:
    """Sum elements along the given axis or axes (all elements if None)."""
    return Sum(axis, keepdims)(x)
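
# Usage sketch: the gradient of a reduction is a broadcast of gy back to the
# input shape (reduced axes are first restored by utils.reshape_sum_backward).
#
#   x = Tensor(np.ones((2, 3)))
#   y = sum(x, axis=1)   # y.shape == (2,)
#   y.backward()
#   # x.grad is all ones with shape (2, 3).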

class SumTo(Function):
    """Sum elements to target shape (inverse of broadcast_to)."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape
    def forward(self, x):
        self.x_shape = x.shape
        return utils.sum_to(x, self.shape)
    def backward(self, gy):
        return broadcast_to(gy, self.x_shape)

def sum_to(x, shape: Tuple[int, ...]) -> Tensor:
    """Sum elements of ``x`` so the result has the given ``shape``."""
    if x.shape == shape:
        return as_tensor(x)
    return SumTo(shape)(x)
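
# Usage sketch: sum_to collapses axes that broadcasting would have expanded.
#
#   x = Tensor(np.arange(6.0).reshape(2, 3))
#   y = sum_to(x, (1, 3))   # column sums, y.shape == (1, 3)
#   # Its backward broadcasts gy from (1, 3) back to (2, 3).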

class BroadcastTo(Function):
    """Broadcast x to shape."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape
    def forward(self, x):
        self.x_shape = x.shape
        return np.broadcast_to(x, self.shape)
    def backward(self, gy):
        return sum_to(gy, self.x_shape)

def broadcast_to(x, shape: Tuple[int, ...]) -> Tensor:
    """Broadcast ``x`` to the given shape."""
    if x.shape == shape:
        return as_tensor(x)
    return BroadcastTo(shape)(x)
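
# Usage sketch: broadcast_to and sum_to are adjoint to each other, so each
# replicated element collects one gradient contribution per copy.
#
#   x = Tensor(np.array([1.0, 2.0, 3.0]))
#   y = broadcast_to(x, (4, 3))
#   y.backward()
#   # x.grad == [4., 4., 4.]: sum_to(gy, (3,)) adds up the four copies.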

def dropout(x, dropout_ratio=0.5) -> Tensor:
    """Dropout during training; identity during evaluation."""
    x = as_tensor(x)
    if syna.Config.train:
        # Inverted dropout: drop units with probability dropout_ratio and
        # rescale survivors by 1 / (1 - dropout_ratio) so the expected
        # activation is unchanged between training and evaluation.
        mask = np.random.rand(*x.shape) > dropout_ratio
        scale = np.array(1.0 - dropout_ratio).astype(x.dtype)
        return x * mask / scale
    return x
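
# Usage sketch: behavior is toggled by syna.Config.train (per the branch
# above); the float32 dtype here is an assumption for illustration.
#
#   x = Tensor(np.ones(5, dtype=np.float32))
#   y = dropout(x, dropout_ratio=0.5)   # roughly half zeros, survivors scaled by 2
#   syna.Config.train = False
#   y = dropout(x, dropout_ratio=0.5)   # identity: returns x unchanged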