Source code for benchbuild.source.base

"""
Provide a base interface for downloadable sources.
"""
import abc
import itertools
import sys
import typing as tp
from typing import Protocol

import attr
import plumbum as pb

from benchbuild.settings import CFG

if tp.TYPE_CHECKING:
    from benchbuild.project import Project


[docs] @attr.s(frozen=True, eq=True) class RevisionStr: # pylint: disable=too-few-public-methods value: str = attr.ib()
[docs] @attr.s(frozen=True, eq=True) class Variant: """ Provide a 'string'-like wrapper around source version information. Use this, whenever you need a 'version' string somewhere in benchbuild. In terms of output/logging or use as program arguments, this should not carry more semantics than a simple version string. However, this wrapper is linked to its 'owner'. The owner serves as the back-reference to the source code where it originated from. This can serve as a 'hook' to deal with version information the same way as a program variant like a specific configuraiton. """ owner: 'FetchableSource' = attr.ib(eq=False, repr=False) version: str = attr.ib()
[docs] def name(self) -> str: return self.owner.local
[docs] def source(self) -> 'FetchableSource': return self.owner
def __str__(self) -> str: return str(self.version)
NestedVariants = tp.Iterable[tp.Tuple[Variant, ...]]
[docs] class Revision: """ A revision captures all variants that form a single project revision. A project may have an arbitrary number of input sources that are required for it's defined workloads, e.g., test input files, optional dependencies, or submodules. BenchBuild considers each source to have different version numbers, encoded as "Variants". The complete set of "Variants" for a project then forms a project revision. """ project_cls: tp.Type["Project"] variants: tp.Sequence[Variant] def __init__( self, project_cls: tp.Type["Project"], _primary: Variant, *variants: Variant ) -> None: self.project_cls = project_cls self.variants = [_primary] + list(variants)
[docs] def extend(self, *variants: Variant) -> None: self.variants = list(self.variants) + list(variants)
def __update_variant(self, variant: Variant) -> None: def __replace(elem: Variant): if elem.name() == variant.name(): return variant return elem self.variants = list(map(__replace, self.variants))
[docs] def update(self, revision: "Revision") -> None: for variant in revision.variants: self.__update_variant(variant)
[docs] def variant_by_name(self, name: str) -> Variant: """ Return the variant for the given source name. Args: name: The local name of the source. Returns: then version of the found source. """ found_variants = [ variant for variant in self.variants if variant.owner.key == name ] if len(found_variants) > 0: return found_variants[0] raise KeyError(f"Source with name {name} not found.")
[docs] def has_variant(self, name: str) -> bool: """ Check if a variant with the given source name exists. Args: name: The local name of the source. Returns: True, should a variant with the given name exists """ return any(variant.owner.key == name for variant in self.variants)
[docs] def source_by_name(self, name: str) -> 'FetchableSource': """ Return the source object that matches the key. Args: name: The local name of the source. Returns: the found source object Raises: KeyError, if we cannot find the source with this name. """ found_variants = [ variant for variant in self.variants if variant.owner.key == name ] if len(found_variants) > 0: return found_variants[0].owner raise KeyError(f"Source with name {name} not found.")
@property def primary(self) -> Variant: variants = self.sorted() return variants[0]
[docs] def sorted(self) -> tp.Sequence[Variant]: """ Return an ordered list of Variants from this revision. The order is defined by the order in the SOURCE attribute of the associated project class. """ sources = self.project_cls.SOURCE sources = [src for src in sources if self.has_variant(src.key)] return [self.variant_by_name(src.key) for src in sources]
def __str__(self) -> str: variants = self.sorted() return to_str(*variants) def __repr__(self) -> str: return str(self)
[docs] def to_str(*variants: Variant) -> str: """ Convert an arbitrary number of variants into their string representation. Returns: string representation of all input variants joined by ','. """ return ",".join([str(i) for i in variants])
[docs] class Fetchable(Protocol): @property def key(self) -> str: """ Return the source's key property. This provides you with a key component that identifes a single source. It should (no guarantee) be unique among all sources for this project. While this make no further assumption, but a good candidate is a file-system name/path. """ @property def local(self) -> str: """ The source location (path-like) after fetching it from its remote. """ @property def remote(self) -> tp.Union[str, tp.Dict[str, str]]: """ The source location in the remote location. """
[docs] def fetch(self) -> pb.LocalPath: """ Fetch the necessary source files into benchbuild's cache. """
[docs] class Expandable(Protocol): @property def is_expandable(self) -> bool: """ Returns true, if this source should take part in version expansion. Some sources may only be treated as virtual and would not take part in the version expansion of an associated project. """
[docs] def versions(self) -> tp.Sequence[Variant]: """ List all available versions of this source. Returns: List[str]: The list of all available versions. """
[docs] class ContextAwareSource(Protocol):
[docs] def is_context_free(self) -> bool: """ Return, if this source needs context to evaluate it's own list of available versions. """
[docs] def versions_with_context(self, ctx: Revision) -> tp.Sequence[Variant]: """ Augment the given revision with new variants associated with this source. Args: ctx: the project revision, containing information about every context-free variant. Returns: a sequence of project revisions. """
[docs] class ContextFreeMixin: """ Make a context-free source context-aware. This will setup default implementations that avoids interaction with any context. """
[docs] def is_context_free(self) -> bool: return True
[docs] def versions_with_context(self, ctx: Revision) -> tp.Sequence[Variant]: raise AttributeError("Invalid use of versions with context")
[docs] class Versioned(Protocol): @property def default(self) -> Variant: """ The default version for this source. """
[docs] def version(self, target_dir: str, version: str) -> pb.LocalPath: """ Fetch the requested version and place it in the target_dir Args: target_dir (str): The filesystem path where the version should be placed in. version (str): The version that should be fetched from the local cache. Returns: str: [description] """
[docs] def versions(self) -> tp.Sequence[Variant]: """ List all available versions of this source. Returns: List[str]: The list of all available versions. """
[docs] class FetchableSource(ContextFreeMixin): """ Base class for fetchable sources. Subclasses have to provide the following protocols: - Expandable - Fetchable - Versioned """ _local: str _remote: tp.Union[str, tp.Dict[str, str]] def __init__(self, local: str, remote: tp.Union[str, tp.Dict[str, str]]): super().__init__() self._local = local self._remote = remote @property def local(self) -> str: return self._local @property def remote(self) -> tp.Union[str, tp.Dict[str, str]]: return self._remote @property def key(self) -> str: return self.local @property @abc.abstractmethod def default(self) -> Variant: """ The default version for this source. """ raise NotImplementedError()
[docs] @abc.abstractmethod def version(self, target_dir: str, version: str) -> pb.LocalPath: """ Fetch the requested version and place it in the target_dir Args: target_dir (str): The filesystem path where the version should be placed in. version (str): The version that should be fetched from the local cache. Returns: str: [description] """ raise NotImplementedError()
[docs] @abc.abstractmethod def versions(self) -> tp.Sequence[Variant]: """ List all available versions of this source. Returns: List[str]: The list of all available versions. """ raise NotImplementedError()
[docs] def explore(self) -> tp.Sequence[Variant]: """ Explore revisions of this source. This provides access to all revisions this source can offer. BenchBuild own filters will not block any revision here. Custom sources or source filters can opt in to block revisions anyways. Returns: List[str]: The list of versions to explore. """ if self.is_context_free(): return self.versions() return []
@property def is_expandable(self) -> bool: return True
[docs] @abc.abstractmethod def fetch(self) -> pb.LocalPath: """ Fetch the necessary source files into benchbuild's cache. """ raise NotImplementedError()
Sources = tp.List['FetchableSource']
[docs] class NoSource(FetchableSource): @property def default(self) -> Variant: return Variant(owner=self, version='None')
[docs] def version(self, target_dir: str, version: str) -> pb.LocalPath: return 'None'
[docs] def versions(self) -> tp.List[Variant]: return [Variant(owner=self, version='None')]
[docs] def fetch(self) -> pb.LocalPath: return 'None'
[docs] def nosource() -> NoSource: return NoSource('NoSource', 'NoSource')
[docs] def target_prefix() -> str: """ Return the prefix directory for all downloads. Returns: str: the prefix where we download everything to. """ return str(CFG['tmp_dir'])
SourceT = tp.TypeVar('SourceT')
[docs] def primary(*sources: SourceT) -> SourceT: """ Return the implicit 'main' source of a project. We define the main source as the first source listed in a project. If you define a new project and rely on the existence of a 'main' source code repository, make sure to define it as the first one. """ (head, *_) = sources return head
[docs] def secondaries(*sources: SourceT) -> tp.Sequence[SourceT]: """ Return the complement to the primary source of a project. Returns: A list of all sources not considered primary. """ (_, *tail) = sources return list(tail)
[docs] def product(*sources: Expandable) -> NestedVariants: """ Return the cross product of the given sources. Returns: An iterable containing the cross product of all source variants. """ siblings = [source.versions() for source in sources if source.is_expandable] return itertools.product(*siblings)
[docs] class BaseSource(Expandable, Versioned, ContextAwareSource, Protocol): """ Composition of source protocols. """
[docs] class EnumeratorFn(Protocol): def __call__(self, *source: Expandable) -> NestedVariants: """ Return an enumeration of all variants for each source. Returns: a list of version tuples, containing each possible variant. """
def _default_enumerator(*sources: Expandable) -> NestedVariants: return product(*sources)
[docs] class ContextEnumeratorFn(Protocol): def __call__( self, project_cls: tp.Type["Project"], context: Revision, *sources: ContextAwareSource ) -> tp.Sequence[Revision]: """ Enumerate all revisions that are valid under the given context. """
def _default_caw_enumerator( project_cls: tp.Type["Project"], context: Revision, *sources: ContextAwareSource ) -> tp.Sequence[Revision]: """ Transform given variant into a list of variants to check. This only considers the given context of all context-free sources per context-sensitive source. Args: context: *sources: """ variants = [source.versions_with_context(context) for source in sources] variants = [var for var in variants if var] ret = [ Revision(project_cls, *(list(context.variants) + list(caw_variants))) for caw_variants in itertools.product(*variants) ] return ret
[docs] def enumerate_revisions( project_cls: tp.Type["Project"], context_free_enumerator: EnumeratorFn = _default_enumerator, context_aware_enumerator: ContextEnumeratorFn = _default_caw_enumerator ) -> tp.Sequence[Revision]: """ Enumerates the given sources. The enumeration requires two phases. 1. A phase for all sources that do not require a context to evaluate. 2. A phase for all sources that require a static context. """ sources = project_cls.SOURCE context_free_sources = [ source for source in sources if source.is_context_free() ] context_aware_sources = [ source for source in sources if not source.is_context_free() ] revisions = context_free_enumerator(*context_free_sources) project_revisions = [ Revision(project_cls, *variants) for variants in revisions ] if len(context_aware_sources) > 0: revs = list( itertools.chain( *( context_aware_enumerator( project_cls, rev, *context_aware_sources ) for rev in project_revisions ) ) ) return revs return project_revisions
SourceContext = tp.Dict[str, Fetchable]
[docs] def sources_as_dict(*sources: Fetchable) -> SourceContext: """ Convert fetchables to a dictionary. The dictionary will be indexed by the Fetchable's local attribute. Args: *sources: Fetchables stored in the dictionary. """ return {src.local: src for src in sources}
[docs] def revision_from_str( revs: tp.Sequence[RevisionStr], project_cls: tp.Type["Project"] ) -> Revision: """ Create a Revision from a sequence of revision strings. A valid Revision can only be created, if the number of valid revision strings is equivalent to the number of sources. A valid revision string is one that has been found in the a source's version. It is required that each revision string is found in a different source version. We assume that the first source is the primary source of the revision. Args: revs: sequence of revision strings, e.g. a commit-hash. *sources: sources of a project. Returns: A variant context. """ found: tp.List[Variant] = [] sources = project_cls.SOURCE for source in sources: if source.is_expandable: found.extend([ variant for variant in source.explore() for rev in revs if variant.version == rev.value ]) if len(found) == 0: raise ValueError(f'Revisions {revs} not found in any available source.') return Revision(project_cls, primary(*found), *secondaries(*found))