
Comments (2)

NickleDave commented on July 22, 2024

Pasting in the WindowDataset class as I have it, with x_source, to refer to later; about to hack this out. The key thing is to use a Literal type and a constant that we add to vak.constants.

from typing import Callable, Literal

...

from .helper import x_vectors_from_df


class WindowDataset(VisionDataset):
    ...
    # class attribute, constant used by several methods
    # with window_inds, to mark invalid starting indices for windows
    INVALID_WINDOW_VAL = -1

    # note the comma after "val": without it, "val" "test" concatenates to "valtest"
    VALID_SPLITS = ("train", "val", "test", "all")

    def __init__(
        self,
        root: str | pathlib.Path,
        x_source: Literal['audio', 'spect'],
        window_inds: npt.NDArray,
        source_ids: npt.NDArray,
        source_inds: npt.NDArray,
        source_paths: list | npt.NDArray,
        annots: list,
        labelmap: dict,
        timebin_dur: float,
        window_size: int,
        spect_key: str = "s",
        timebins_key: str = "t",
        transform: Callable | None = None,
        target_transform: Callable | None = None,
    ):
        r"""Initialize a WindowDataset instance.

        Parameters
        ----------
        root : str, Path
            Path to a .csv file that represents the dataset.
            Name 'root' is used for consistency with torchvision.datasets.
        x_source : str
            One of {'audio', 'spect'}. The source
            of the data, either audio files ('audio')
            or spectrograms in array files ('spect'),
            from which we take windows.
            These windows become the samples :math:`x`
            that are inputs for the network during training.
            ...
        """

        ...  # implement as described above, using `x_source` for control flow
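The class takes three precomputed vectors (window_inds, source_ids, source_inds) and uses INVALID_WINDOW_VAL to mark starting indices that cannot begin a window. Below is a minimal sketch of how such vectors could be built for a toy dataset. It assumes, as a simplification, that windows are drawn from the concatenation of all sources and that only the last window_size positions are marked invalid, which yields the (t - window_size) valid windows described in the docstring of x_vectors_from_df. make_window_vectors is a hypothetical helper, not part of vak.

```python
import numpy as np

# Sentinel for starting indices that cannot begin a window,
# mirroring WindowDataset.INVALID_WINDOW_VAL in the snippet above.
INVALID_WINDOW_VAL = -1


def make_window_vectors(
    lengths: list[int], window_size: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical helper: build source_ids, source_inds, and window_inds
    for sources (audio arrays or spectrograms) with the given lengths."""
    lengths_arr = np.asarray(lengths, dtype=int)
    total = int(lengths_arr.sum())
    if not 0 < window_size <= total:
        raise ValueError(f"window_size must be in (0, {total}], got {window_size}")
    # source_ids[i]: which source (index into source_paths) position i comes from
    source_ids = np.repeat(np.arange(len(lengths_arr)), lengths_arr)
    # source_inds[i]: index of position i within its own source
    source_inds = np.concatenate([np.arange(n) for n in lengths_arr])
    # window_inds: one entry per position across the whole dataset; the last
    # `window_size` entries are marked invalid, leaving (total - window_size)
    # valid starting indices
    window_inds = np.arange(total)
    window_inds[-window_size:] = INVALID_WINDOW_VAL
    return source_ids, source_inds, window_inds


# Example: two sources with 4 and 3 time bins, windows of 2 time bins
source_ids, source_inds, window_inds = make_window_vectors([4, 3], window_size=2)
# source_ids.tolist()  -> [0, 0, 0, 0, 1, 1, 1]
# source_inds.tolist() -> [0, 1, 2, 3, 0, 1, 2]
# window_inds.tolist() -> [0, 1, 2, 3, 4, -1, -1]
```

Marking invalid entries with a sentinel, rather than deleting them, keeps window_inds aligned index-for-index with source_ids and source_inds in this sketch.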

# x_vectors_from_df lives in the .helper module; note use of constant in pre-conditions
def x_vectors_from_df(
        df: pd.DataFrame,
        x_source: Literal['audio', 'spect'],
        split: str,
        window_size: int,
        audio_format: str = "wav",
        spect_key: str = "s",
        timebins_key: str = "t",
        crop_dur: float | None = None,
        timebin_dur: float | None = None,
        labelmap: dict | None = None,
) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
    r"""Get ``source_ids``, ``source_inds``, and ``window_inds``
    from a dataframe that represents a dataset of vocalizations.

    See ``vak.datasets.WindowDataset`` for a
    detailed explanation of these vectors.

    Parameters
    ----------
    df : pandas.DataFrame
        That represents a dataset of vocalizations.
    x_source : str
        One of {'audio', 'spect'}. The source
        of the data, either audio files ('audio')
        or spectrograms in array files ('spect'),
        from which we take windows.
        These windows become the samples :math:`x`
        that are inputs for the network during training.
    window_size : int
        Size of the window, in number of time bins,
        that is taken from the audio array
        or spectrogram to become a training sample.
    audio_format : str
        Valid audio file format. One of {"wav", "cbin"}.
        Defaults to "wav".
    spect_key : str
        Key to access spectrograms in array files.
        Default is "s".
    timebins_key : str
        Key to access time bin vector in array files.
        Default is "t".
    crop_dur : float
        Duration to which dataset should be "cropped". Default is None,
        in which case entire duration of specified split will be used.
    timebin_dur : float
        Duration of a single time bin in spectrograms. Default is None.
        Used when "cropping" dataset with ``crop_dur``, and required if a
        value is specified for that parameter.
    labelmap : dict
        Dict that maps labels from dataset to a series of consecutive integers.
        To create a label map, pass a set of labels to the ``vak.utils.labels.to_map`` function.
        Used when "cropping" dataset with ``crop_dur``
        to make sure all labels in ``labelmap`` are still
        in the dataset after cropping.
        Required if a value is specified for ``crop_dur``.

    Returns
    -------
    source_ids : numpy.ndarray
        Represents the "id" of any audio array or spectrogram,
        i.e., the index into ``source_paths`` that will let us load it.
    source_inds : numpy.ndarray
        Valid indices of windows we can grab from each
        audio array or spectrogram.
    window_inds : numpy.ndarray
        Vector of all valid starting indices of all windows in the dataset.
        This vector is what is used by PyTorch to determine
        the number of samples in the dataset, via the
        ``WindowDataset.__len__`` method.
        Without cropping, a dataset with ``t`` total time bins
        across all audio arrays or spectrograms will have
        (``t`` - ``window_size``) valid windows, with starting indices
        drawn from (0, 1, 2, ..., ``t`` - 1).
        But cropping with ``crop_dur`` will
        remove some of these indices.
    """
    from .class_ import WindowDataset  # avoid circular import

    # ---- pre-conditions
    if x_source not in constants.VALID_X_SOURCES:
        raise ValueError(
            f"`x_source` must be one of {constants.VALID_X_SOURCES} but was: {x_source}"
        )
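The comment's key idea is a Literal type hint for x_source paired with a constant in vak.constants. One way to keep the two from drifting apart, sketched below under the assumption that the constant is simply a tuple of the allowed strings, is to derive VALID_X_SOURCES from the Literal with typing.get_args. XSource and check_preconditions are hypothetical names; the crop_dur checks only restate the requirements listed in the docstring above.

```python
from __future__ import annotations

from typing import Literal, get_args

# Hypothetical: derive the runtime constant from the Literal so the type hint
# and the pre-condition check cannot drift apart. In vak, this tuple would be
# the VALID_X_SOURCES constant added to vak.constants.
XSource = Literal["audio", "spect"]
VALID_X_SOURCES: tuple[str, ...] = get_args(XSource)


def check_preconditions(
    x_source: str,
    crop_dur: float | None = None,
    timebin_dur: float | None = None,
    labelmap: dict | None = None,
) -> None:
    """Hypothetical pre-condition checks mirroring the x_vectors_from_df docstring."""
    if x_source not in VALID_X_SOURCES:
        raise ValueError(
            f"`x_source` must be one of {VALID_X_SOURCES} but was: {x_source}"
        )
    # per the docstring, timebin_dur and labelmap are required when cropping
    if crop_dur is not None:
        if timebin_dur is None:
            raise ValueError("`timebin_dur` is required when `crop_dur` is specified")
        if labelmap is None:
            raise ValueError("`labelmap` is required when `crop_dur` is specified")


check_preconditions("audio")  # passes
check_preconditions("spect", crop_dur=10.0, timebin_dur=0.002, labelmap={"a": 0})
```

With this arrangement, adding a new source type means editing only the Literal; the runtime tuple follows automatically.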


NickleDave commented on July 22, 2024

I ended up just pushing a branch "add-prep-audio-dataset" with the changes I'd made to WindowDataset and its helper for the audio format; will figure out how to incorporate those changes via git suffering later.
