import numpy as np
from scipy.special import factorial
import matplotlib.pyplot as plt
from angstrom.pyramids.pyramid_utils import get_polar_grid, get_filter_crops
# Steerable Pyramid Base class
[docs]
class SteerablePyramid():
[docs]
def __init__(
self,
depth,
orientations,
filters_per_octave=1,
twidth=1,
complex_pyr=False):
"""
depth - Pyramid Depth (height) of Pyramid Decomposition
orientations - number of orientations (number of bands)
filters_per_octave - Nubmer of Filters in a single Octave
twidth - transition width between Low and High Pass Filters (controls abruptness of the falloff)
complex_pyr - determines whether to create a Complex Pyramid
Recommended inputs:
- Single Octave BW --> default args
- Half Octave BW --> filters_per_octave=2, twidth=0.5
- A smaller twdith (transition region) appears to provide
a better reconstructon error.
"""
# max_depth = int(np.floor(np.log2(np.min(np.array(image.shape)))) - 2)
self.depth = depth
self.orientations = orientations
self.twidth = twidth
self.complex_pyr = complex_pyr
# number of filters in each band (does not include hi and lo pass)
self.num_filts = depth * filters_per_octave
# octaves per filter (bandwidth in terms of octaves)
self.octave_bw = 1.0 / filters_per_octave
def _get_radial_mask(self, radius, r):
""" Obtains Radial High and Low Pass Filters
Inputs:
radius - radius of polar grid
r - specifies filter boundary where hi is 1 and where lo is mostly 1
Outputs:
lo_mask - Low Pass Filter
hi_mask - High Pass Filter
"""
# shift log radius (shifts by an octave if log2(r) = 1)
log_rad = np.log2(radius) - np.log2(r)
hi_mask = np.clip(log_rad, -self.twidth, 0)
hi_mask = np.abs(np.cos(hi_mask * np.pi / (2 * self.twidth)))
lo_mask = np.sqrt(1.0 - hi_mask**2)
return lo_mask, hi_mask
def _get_angle_mask(self, angle, b):
""" Obtains Angle Mask
Inputs:
angle - Angular component of polar coordinate grid
b - bth band or orientation
Outputs:
angle_mask - Angular mask
"""
order = self.orientations - 1
const = np.power(2,
(2 * order)) * np.power(factorial(order),
2) / (self.orientations * factorial(2 * order))
angle = np.mod(
np.pi + angle - np.pi * b / self.orientations,
2 * np.pi) - np.pi
if self.complex_pyr:
# complex (only use single lobe due to conjugate symmetry)
angle_mask = 2 * \
np.sqrt(const) * np.power(np.cos(angle), order) * (np.abs(angle) < np.pi / 2)
else:
# non-complex take magnitude to ensure both lobes are acquired
angle_mask = np.abs(
2 *
np.sqrt(const) *
np.power(
np.cos(angle),
order))
return angle_mask
[docs]
def get_filters(self, image, cropped=False):
""" Obtains cropped? Steerable Pyramid Filters
Inputs:
image - input single channel image
full - determines whether or not to crop the filters
Outputs:
filters - list of pyramid filters
crops - crop indices associated with each filter
"""
h, w = image.shape
angle, radius = get_polar_grid(h, w)
# radial_vals specify radial spacing between adjacent filters
# they determine the lo/hi cutoffs
radial_vals = 2.0**np.arange(-self.depth,
self.octave_bw, self.octave_bw)[::-1]
# get initial Low and High Pass Filters
lo_mask_prev, hi_mask = self._get_radial_mask(radius, r=radial_vals[0])
# get initial crop index
crop = get_filter_crops(hi_mask)
crops = [crop]
if cropped:
filters = [hi_mask[crop[0]:crop[1], crop[2]:crop[3]]]
else:
filters = [hi_mask]
for idx, rval in enumerate(radial_vals[1:]):
# obtain Radial Band Filter Mask
lo_mask, hi_mask = self._get_radial_mask(radius, rval)
rad_mask = hi_mask * lo_mask_prev
# obtain crops indexes for current level
if idx > 0:
crop = get_filter_crops(rad_mask)
# get filters at each band (orientation)
for b in range(self.orientations):
# get Anglular Filter Mask
angle_mask = self._get_angle_mask(angle, b)
filt = rad_mask * angle_mask / 2
if cropped:
filters.append(filt[crop[0]:crop[1], crop[2]:crop[3]])
else:
filters.append(filt)
# store crop dimensions for current Pyramid Level
crops.append(crop)
lo_mask_prev = lo_mask
# get final Low Pass Filter Mask and crop dims
crop = get_filter_crops(lo_mask)
crops.append(crop)
if cropped:
filters.append(lo_mask[crop[0]:crop[1], crop[2]:crop[3]])
else:
filters.append(lo_mask)
return filters, crops
[docs]
def build_pyramid(self, image, cropped_filters, crops, freq=False):
""" Build Pyramid Decomposition
Inputs:
image - input single channel image
cropped_filters - cropped filters
crops - filter crop indices
Outputs:
pyramid - output list of pyramid decomposition
"""
image_dft = np.fft.fftshift(np.fft.fft2(image))
pyramid = []
idx = 0
# Highpass (first filter)
filt, crop = cropped_filters[idx], crops[idx]
dft = image_dft[crop[0]:crop[1], crop[2]:crop[3]] * filt
if freq:
pyramid.append(dft)
elif self.complex_pyr:
pyramid.append(np.fft.ifft2(np.fft.ifftshift(dft)))
else:
pyramid.append(np.fft.ifft2(np.fft.ifftshift(dft)).real)
idx += 1
# Bandpass levels (grouped by orientations)
for band_level in range(self.num_filts):
bands = []
for b in range(self.orientations):
filt, crop = cropped_filters[idx], crops[idx]
dft = image_dft[crop[0]:crop[1], crop[2]:crop[3]] * filt
if freq:
band = dft
elif self.complex_pyr:
band = np.fft.ifft2(np.fft.ifftshift(dft))
else:
band = np.fft.ifft2(np.fft.ifftshift(dft)).real
bands.append(band)
idx += 1
pyramid.append(bands)
# Lowpass (last filter)
filt, crop = cropped_filters[idx], crops[idx]
dft = image_dft[crop[0]:crop[1], crop[2]:crop[3]] * filt
if freq:
pyramid.append(dft)
elif self.complex_pyr:
pyramid.append(np.fft.ifft2(np.fft.ifftshift(dft)))
else:
pyramid.append(np.fft.ifft2(np.fft.ifftshift(dft)).real)
return pyramid
[docs]
def build_pyramid_full(self, image, filters, freq=False):
""" Vectorized Pyramid Decomposition with uncropped filters array
Inputs:
image - input single channel image
filters - uncropped filters array
freq - determines whether to build frequency domain pyramid or spatial domain
Outputs:
pyramid - output list of pyramid decomposition
"""
image_dft = np.fft.fftshift(np.fft.fft2(image))[None, :, :]
dft = image_dft * filters
if freq:
return dft
if self.complex_pyr:
pyramid = np.fft.ifft2(np.fft.ifftshift(dft, axes=(1, 2)))
else:
pyramid = np.fft.ifft2(np.fft.ifftshift(dft, axes=(1, 2))).real
return pyramid
[docs]
def reconstruct_image_dft(
self,
pyramid,
cropped_filters,
crops,
freq=False):
"""
Reconstructs image DFT from the pyramid decomposition.
Accepts grouped (bandpass as lists) or flat structure.
"""
# Always flatten the pyramid for legacy code
pyramid = flatten_pyramid(pyramid)
h, w = pyramid[0].shape
recon_dft = np.zeros((h, w), dtype=np.complex128)
for i, (pyr, filt, crop) in enumerate(
zip(pyramid, cropped_filters, crops)):
# dft of sub band
if freq:
dft = pyr
else:
dft = np.fft.fftshift(np.fft.fft2(pyr))
# accumulate reconstructed sub bands
if self.complex_pyr and (i != 0) and (
i != (len(cropped_filters) - 1)):
recon_dft[crop[0]:crop[1], crop[2]:crop[3]] += 2.0 * dft * filt
else:
recon_dft[crop[0]:crop[1], crop[2]:crop[3]] += dft * filt
return recon_dft
[docs]
def reconstruct_image_dft_full(self, pyramid, filters, freq=False):
""" Reconstructs image DFT from the pyramid decomposition with full pyramid
and uncropped filters
Inputs:
pyramid - Complex Steerable Pyramid Decomposition
(either spatial or frequency domain)
filters - uncropped filters
freq - flag to denote whether input pyramid is in frequency space
Outputs:
recon_dft - reconstructed image DFT
"""
h, w = pyramid[0].shape
recon_dft = np.zeros((h, w), dtype=np.complex128)
for i, (pyr, filt) in enumerate(zip(pyramid, filters)):
# dft of sub band
if freq:
dft = pyr
else:
dft = np.fft.fftshift(np.fft.fft2(pyr))
# accumulate reconstructed sub bands
if self.complex_pyr and (i != 0) and (i != (len(filters) - 1)):
recon_dft += 2.0 * dft * filt
else:
recon_dft += dft * filt
return recon_dft
[docs]
def reconstruct_image(
self,
pyramid,
filters,
crops=None,
full=False,
freq=False):
""" Reconstructs image from the pyramid decomposition.
Inputs:
pyramid - Complex Steerable Pyramid Decomposition
filters - cropped filters
crops - filter crop indices
full - denotes whether to use full or cropped pyramid approach
freq - denotes whether input pyramid is in Frequency or Spatial Domain
Outputs:
recon_dft - reconstructed image DFT
"""
if full:
recon_dft = self.reconstruct_image_dft_full(pyramid, filters, freq)
else:
recon_dft = self.reconstruct_image_dft(
pyramid, filters, crops, freq)
return np.fft.ifft2(np.fft.ifftshift(recon_dft)).real
[docs]
def display(self, filters, title=""):
""" Displays all Pyramid Filters except for Hi and Lo pass masks
Inputs:
filters - cropped filters list or pyramid list
title - title for figure
"""
fig, ax = plt.subplots(
self.num_filts, self.orientations, figsize=(
30, 20))
fig.suptitle(title, size=22)
idx = 0
for i in range(self.num_filts):
idx = i * self.orientations
for j in range(1, self.orientations + 1):
jdx = idx + j
ax[i][j - 1].imshow(filters[jdx])
plt.tight_layout()
return fig, ax
# Sub Octave Smooth window Pyramid class
[docs]
class SuboctaveSP(SteerablePyramid):
[docs]
def __init__(
self,
depth,
orientations,
filters_per_octave,
cos_order=6,
complex_pyr=True):
"""
depth - Pyramid Depth (height)
orientations - number of orientations
filters_per_octave - number of filters per octave (specifies frequency spacing of adjacent filters)
cos_order - order of cosine smoothing function
complex_pyr - determines whether to create a Complex Pyramid
NOTE: there is an issue with non-complex pyramids, not sure what it is
"""
self.depth = depth
self.num_filts = depth * filters_per_octave
self.orientations = orientations
self.filters_per_octave = filters_per_octave
self.cos_order = cos_order
self.complex_pyr = complex_pyr
def _get_angle_mask_smooth(self, angle, b):
""" Obtains Angle Mask
Inputs:
angle - Angular component of polar coordinate grid
b - bth band or orientation
Outputs:
angle_mask - Angular mask
"""
order = self.orientations - 1
const = np.power(2, (2 * order)) \
* np.power(factorial(order), 2) \
/ (self.orientations * factorial(2 * order))
angle = np.mod(
np.pi + angle - np.pi * b / self.orientations,
2 * np.pi) - np.pi
# if self.complex_pyr:
# # complex (only use single lobe due to conjugate symmetry)
# angle_mask = np.sqrt(const) \
# * np.power(np.cos(angle), order) \
# * (np.abs(angle) < np.pi/2)
# else:
# # non-complex
# angle_mask = np.abs(np.sqrt(const) \
# * np.power(np.cos(angle), order))
angle_mask = np.sqrt(const) \
* np.power(np.cos(angle), order) \
* (np.abs(angle) < np.pi / 2)
return angle_mask
[docs]
@staticmethod
def window_func(x, center):
return np.abs(x - center) < np.pi / 2
[docs]
def get_filters(self, image, cropped=False):
""" Builds Filters
Inputs:
image - input image
cropped - determines whether to crop filters or not
Outputs:
filters - list of pyramid filters
crops - crop indices associated with each filter
"""
h, w = image.shape
angle, radius = get_polar_grid(h, w)
# get log radius
rad = np.log2(radius)
rad = (self.depth + rad) / self.depth
rad = rad * (np.pi / 2 + np.pi / 7 * self.num_filts)
# Build Radial Filters
rad_filters = []
total = np.zeros((h, w))
const = np.power(2, 2 * self.cos_order) \
* np.power(factorial(self.cos_order), 2) \
/ ((self.cos_order + 1) * factorial(2 * self.cos_order))
for k in reversed(range(self.num_filts)):
shift = np.pi / (self.cos_order + 1) * (k + 1) + 2 * np.pi / 7
rad_filters.append(np.sqrt(const)
* np.power(np.cos(rad - shift), self.cos_order)
* self.window_func(rad, shift))
total += rad_filters[-1]**2
# get lo and hi pass filters
dims = np.array([h, w])
center = np.ceil(dims / 2).astype(int)
lodims = np.ceil(center / 4).astype(int)
idx11 = center[0] - lodims[0]
idx12 = center[0] + lodims[0]
idx21 = center[1] - lodims[1]
idx22 = center[1] + lodims[1]
total_crop = total[idx11:idx12, idx21:idx22]
lopass = np.zeros((h, w))
lopass[idx11:idx12, idx21:idx22] = np.sqrt(np.abs(1 - total_crop))
hipass = np.sqrt(np.abs(1 - (total + lopass**2)))
# build angle masks
angle_masks = []
for b in range(self.orientations):
angle_masks.append(self._get_angle_mask_smooth(angle, b))
# Get Sub Band Filters and Crops
filters = []
crops = []
filters = [hipass]
crops.append(get_filter_crops(hipass))
for rad_filt in rad_filters:
for ang_mask in angle_masks:
filt = rad_filt * ang_mask
crop = get_filter_crops(filt)
crops.append(crop)
if cropped:
filters.append(filt[crop[0]:crop[1], crop[2]:crop[3]])
else:
filters.append(filt)
crop = get_filter_crops(lopass)
crops.append(crop)
if cropped:
filters.append(lopass[crop[0]:crop[1], crop[2]:crop[3]])
else:
filters.append(lopass)
return filters, crops
[docs]
def flatten_pyramid(pyramid):
flat = []
for level in pyramid:
if isinstance(level, list):
flat.extend(level)
else:
flat.append(level)
return flat