Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions cebra/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,22 @@ def expand_index_in_trial(self, index, trial_ids, trial_borders):
trial_ids is in size of a length of self.index and indicate the trial id of the index belong to.
trial_borders is in size of a length of self.idnex and indicate the border of each trial.

Todo:
- rewrite
"""

# TODO(stes) potential room for speed improvements by pre-allocating these tensors/
# using non_blocking copy operation.
offset = torch.arange(-self.offset.left,
self.offset.right,
device=index.device)
index = torch.tensor(
[
torch.clamp(
i,
trial_borders[trial_ids[i]] + self.offset.left,
trial_borders[trial_ids[i] + 1] - self.offset.right,
) for i in index
],
device=self.device,
)

# Vectorized lookup and boundary calculation
Comment thread
stes marked this conversation as resolved.
Outdated
batch_trial_ids = trial_ids[index]
min_borders = trial_borders[batch_trial_ids] + self.offset.left
max_borders = trial_borders[batch_trial_ids + 1] - self.offset.right

# Fast C-level clamp
Comment thread
stes marked this conversation as resolved.
Outdated
index = torch.clamp(index, min=min_borders, max=max_borders)

return index[:, None] + offset[None, :]

@abc.abstractmethod
Expand Down
Loading