Source code for bionumpy.genomic_data.genomic_intervals

from abc import ABC, abstractmethod, abstractproperty, abstractclassmethod
from collections import namedtuple

import numpy as np
from typing import List, Iterable, Tuple, Dict

from .coordinate_mapping import map_locations, find_indices
from ..bnpdataclass import BNPDataClass, replace, bnpdataclass
from .genomic_track import GenomicArray, GenomicArrayNode
from .genome_context_base import GenomeContextBase
from ..datatypes import Interval, Bed6, StrandedInterval, LocationEntry, StrandedLocationEntry
from ..arithmetics.intervals import get_pileup, merge_intervals, extend_to_size, clip, get_boolean_mask, RawInterval
from ..computation_graph import StreamNode, Node, ComputationNode, compute

from ..string_array import StringArray


class GenomicPlace:
    @property
    def genome_context(self):
        return self._genome_context

    @abstractproperty
    def get_location(self, where='start'):
        return NotImplemented

    def get_data_field(self, field_name: str):
        return NotImplemented

    def set_strand(self, strand):
        self._is_stranded = True
        self._strand = strand


[docs] class GenomicLocation(GenomicPlace): '''Class representing (possibliy stranded) locations in the genome''' @abstractproperty def chromosome(self): return NotImplemented @abstractproperty def position(self): return NotImplemented @abstractproperty def strand(self): return NotImplemented @abstractmethod def is_stranded(self): return NotImplemented
[docs] @classmethod def from_fields(cls, genome_context: GenomeContextBase, chromosome: List[str], position: List[int], strand: List[str] = None) -> 'GenomicLocation': """Create genomic location from a genome context and the needed fields (chromosome and position) Parameters ---------- genome_context : GenomeContextBase Genome context object for the genome chromosome : List[str] List of chromosome position : List[int]: List of positions strand : List[str] Optional list of strand Returns ------- 'GenomicLocation' """ is_stranded = strand is not None if is_stranded: data = StrandedLocationEntry(chromosome, position, strand) else: data = LocationEntry(chromosome, position) return GenomicLocationGlobal.from_data(data, genome_context, is_stranded=is_stranded)
[docs] @classmethod def from_data(cls, data: BNPDataClass, genome_context: GenomeContextBase, is_stranded: bool = False, chromosome_name: str = 'chromosome', position_name: str = 'position', strand_name: str = 'strand') -> 'GenomicLocation': """Create GenomicLocation object from a genome context and a bnpdataclass The field names for the chromosome, positions and strand can be specified Parameters ---------- cls : 3 4 data : BNPDataClass The data containing the locations genome_context : GenomeContextBase Genome context for the genome is_stranded : bool Whether or not the locations should be stranded chromosome_name : str Name of the chromosome field in `data` position_name : str Name of the position field in the `data` strand_name : str Name if the `strand` field int the `data` Returns ------- 'GenomicLocation' """ assert all(hasattr(data, name) for name in (chromosome_name, position_name)) if is_stranded: assert hasattr(data, strand_name) return GenomicLocationGlobal(genome_context.mask_data(data), genome_context, is_stranded, {'chromosome': chromosome_name, 'position': position_name, 'strand': strand_name})
class GenomicLocationGlobal(GenomicLocation): ''' Class for genomic locations that are kept entirely in memory''' def __init__(self, locations: BNPDataClass, genome_context: GenomeContextBase, is_stranded: bool, field_dict: Dict[str, str]): self._locations = locations self._genome_context = genome_context self._is_stranded = is_stranded self._field_dict = field_dict @property def data(self): return self._locations def __replace__(self, **kwargs): kwargs = {self._field_dict[kw]: value for kw, value in kwargs.items()} return self.__class__(replace(self._locations, **kwargs), self._genome_context, self._is_stranded, self._field_dict) @property def chromosome(self): return getattr(self._locations, self._field_dict['chromosome']) @property def position(self): return getattr(self._locations, self._field_dict['position']) @property def strand(self): if not self.is_stranded(): raise ValueError('Unstranded position has not strand') return getattr(self._locations, self._field_dict['strand']) def is_stranded(self): return self._is_stranded def get_windows(self, flank: int = None, window_size: int = None) -> 'GenomicIntervals': """Create windows around the locations. `Flank specifies the flank on either side of the location. The full windows will thus be `flank*2+1` wide Parameters ---------- flank : int Flank on either side of the location Returns ------- GenomicIntervals Window intervals """ if flank is not None: assert window_size is None l_flank = flank r_flank = flank + 1 else: assert window_size is not None l_flank = window_size//2 r_flank = window_size//2 + window_size % 2 if self.is_stranded(): intervals = StrandedInterval(self.chromosome, self.position-l_flank, self.position+r_flank, self.strand) else: intervals = Interval(self.chromosome, self.position-l_flank, self.position+r_flank) return GenomicIntervalsFull(intervals, self._genome_context, is_stranded=self.is_stranded()).clip() def sorted(self) -> GenomicLocation: """Return a sorted version of the locations Sorted according the chromosome order in the `GenomeContext` Returns ------- GenomicLocation Sorted locations """ return self[np.lexsort([self.position, self.chromosome.raw()])] def __getitem__(self, idx): return self.__class__(self._locations[idx], self._genome_context, self._is_stranded, self._field_dict) def get_data_field(self, field_name: str): return getattr(self._locations, field_name) class GenomicLocationStreamed(GenomicLocation, Node): ''' Class for representing intervals that are grouped by chromosome, and where only intervals for one chromosome at the time is kept in memory ''' is_stream = True def _get_chrom_size(self, intervals: Interval): return self._genome_context.chrom_sizes[intervals.chromosome] def __str__(self): return 'GLS:' + str(self._data_node) def __repr__(self): return 'GLS:' + str(self._data_node) def __init__(self, data_node: Node, genome_context: GenomeContextBase, is_stranded=False, field_dict: Dict[str, str]=None): if field_dict is None: field_dict = {name: name for name in ['chromosome', 'positions', 'strand']} self._genome_context = genome_context self._chromosome = ComputationNode(getattr, [data_node, field_dict['chromosome']]) self._position = ComputationNode(getattr, [data_node, field_dict['position']]) if is_stranded: self._strand = ComputationNode(getattr, [data_node, field_dict['strand']]) self._chrom_size_node = StreamNode(iter(self._genome_context.chrom_sizes.values())) self._data_node = data_node self._is_stranded = is_stranded def is_stranded(self): return self._is_stranded def sorted(self): return NotImplemented @property def position(self): return self._position @property def chromosome(self): return self._chromosome def get_data_field(self, field_name: str): return ComputationNode(getattr, [self._data_node, field_name]) @property def strand(self): if not self.is_stranded(): raise ValueError('Strand not supported on unstranded intervals') return self._strand def __getitem__(self, item): return self.__class__(ComputationNode(lambda x, i: x[i], [self._data_node, item]), self._genome_context) def get_windows(self, flank: int = None, window_size: int = None) -> 'GenomicIntervals': """Create windows around the locations. `Flank specifies the flank on either side of the location. The full windows will thus be `flank*2+1` wide Parameters ---------- flank : int Flank on either side of the location Returns ------- GenomicIntervals Window intervals """ if flank is not None: assert window_size is None l_flank = flank r_flank = flank + 1 else: assert window_size is not None l_flank = window_size//2 r_flank = window_size//2 + window_size % 2 if self.is_stranded(): intervals = ComputationNode(StrandedInterval, [self.chromosome, self.position-l_flank, self.position+r_flank, self.strand]) else: intervals = ComputationNode( Interval, [self.chromosome, self.position-l_flank, self.position+r_flank]) return GenomicIntervalsStreamed(intervals, self._genome_context, is_stranded=self.is_stranded()).clip() def _get_buffer(self, i): return GenomicLocationGlobal( LocationEntry(self.chromosome._get_buffer(i), self.position._get_buffer(i)), self._genome_context)
[docs] class GenomicIntervals(GenomicPlace): ''' Class for representing intervals on a genome''' @abstractproperty def start(self): return NotImplemented @abstractproperty def stop(self): return NotImplemented @abstractproperty def chromosome(self): return NotImplemented @abstractproperty def strand(self): return NotImplemented @abstractmethod def is_stranded(self): return NotImplemented @abstractmethod def get_location(self, where: str = 'start') -> GenomicLocation: return NotImplemented
[docs] @abstractmethod def extended_to_size(self, size: int) -> 'GenomicIntervals': """Extend intervals along strand to reach the given size Parameters ---------- size : int Returns ------- 'GenomicIntervals' """ return NotImplemented
[docs] @abstractmethod def merged(self, distance: int = 0) -> 'GenomicIntervals': """Merge intervals that overlap or lie within distance of eachother Parameters ---------- distance : int Returns ------- 'GenomicIntervals' 4 """ return NotImplemented
[docs] @abstractmethod def get_mask(self) -> GenomicArray: """Return a boolean mask of areas covered by any interval Returns ------- GenomicArray Genomic mask """ return NotImplemented
[docs] @abstractmethod def get_pileup(self) -> GenomicArray: """Return a genmic track of counting the number of intervals covering each bp Returns ------- GenomicArray Pileup track """ return NotImplemented
[docs] @classmethod def from_track(cls, track: GenomicArray) -> 'GenomicIntervals': """Return intervals of contigous areas of nonzero values of track Parameters ---------- track : GenomicArray Returns ------- 'GenomicIntervals' """ if isinstance(track, GenomicArrayNode): return GenomicIntervalsStreamed(track.get_data(), track.genome_context) return GenomicIntervalsFull(track.get_data(), track.genome_context)
@classmethod def from_fields(cls, genome_context: GenomeContextBase, chromosome, start, stop, strand=None): is_stranded = strand is not None if is_stranded: intervals = Bed6(chromosome, start, stop, ['.']*len(start), np.zeros_like(start), strand) else: intervals = Interval(chromosome, start, stop) return cls.from_intervals(intervals, genome_context, is_stranded=is_stranded)
[docs] @classmethod def from_intervals(cls, intervals: Interval, genome_context: GenomeContextBase, is_stranded=False): """Create genomic intervals from interval entries and genome info Parameters ---------- intervals : Interval chrom_sizes : Dict[str, int] """ if isinstance(intervals, Interval): #TODO check is node return GenomicIntervalsFull(genome_context.mask_data(intervals), genome_context, is_stranded) else: return cls.from_interval_stream(intervals, genome_context, is_stranded)
[docs] @classmethod def from_interval_stream(cls, interval_stream: Iterable[Interval], genome_context: GenomeContextBase, is_stranded=False): """Create streamed genomic intervals from a stream of intervals and genome info Parameters ---------- interval_stream : Iterable[Interval] chrom_sizes : Dict[str, int] """ interval_stream = genome_context.iter_chromosomes( interval_stream, StrandedInterval if is_stranded else Interval) return GenomicIntervalsStreamed(StreamNode(interval_stream), genome_context, is_stranded=is_stranded)
[docs] @abstractmethod def clip(self) -> 'GenomicIntervals': """Clip the intervals so that they are contained in the genome Returns ------- 'GenomicIntervals' Clipped intervals """ return NotImplemented
def compute(self): return NotImplemented
class GenomicIntervalsFull(GenomicIntervals): ''' Class for holding a set of intervals in memory''' is_stream = False def __init__(self, intervals: Interval, genome_context: GenomeContextBase, is_stranded=False): self._intervals = intervals self._is_stranded = is_stranded self._genome_context = genome_context @property def data(self): return self._intervals def __array_function__(self, func: callable, types: List, args: List, kwargs: Dict): if func == np.concatenate: return self.__class__(np.concatenate([obj._intervals for obj in args[0]]), self._genome_context, self._is_stranded) return NotImplemented def __repr__(self): return f'Genomic Intervals on {self._genome_context}:\n{self._intervals.astype(Interval)}' def get_data(self) -> BNPDataClass: """Return the underlying data for the intervals Returns ------- BNPDataClass The data for the intervals """ return self._intervals def __len__(self) -> int: return len(self._intervals) def map_locations(self, locations: LocationEntry): go = self._genome_context.global_offset.from_local_interval(self._intervals) global_positions = self._genome_context.global_offset.from_local_coordinates(locations.chromosome, locations.position) location_indices, interval_indices = find_indices(global_positions, go) new_entries = locations[location_indices] names = self._intervals.name if hasattr(self._intervals, 'name') else StringArray(np.arange(len(self._intervals)).astype('S')) return replace(new_entries, chromosome=names[interval_indices], position=new_entries.position - self.start[interval_indices]) return map_locations(replace(locations, position=global_positions), go) def sorted(self) -> 'GenomicIntervals': """Return the intervals sorted according to `genome_context` Returns ------- 'GenomicIntervals' """ args = np.lexsort([self.stop, self.start, self.chromosome.raw()]) return self[args] def __getitem__(self, idx): return self.__class__(self._intervals[idx], self._genome_context, self._is_stranded) def get_location(self, where: str = 'start') -> GenomicLocation: """Get the genomic location of eitert 'start', 'stop' or 'center' of the intervals Parameters ---------- where : str 'start', 'stop' or 'center' Returns ------- GenomicLocation """ if where in ('start', 'stop'): if not self.is_stranded(): data = self._intervals else: location = np.where(self.strand==('+' if where=='start' else '-'), self.start, self.stop-1) data = replace(self._intervals, start=location) else: assert where == 'center' location = (self.start+self.stop)//2 data = replace(self._intervals, start=location) return GenomicLocationGlobal.from_data( data, self._genome_context, is_stranded=self.is_stranded(), position_name='start') @property def start(self) -> int: return self._intervals.start @property def stop(self) -> int: return self._intervals.stop @property def strand(self) -> str: if not self.is_stranded(): raise ValueError('Unstranded interval has not strand') return self._intervals.strand def get_data_field(self, field_name: str): return getattr(self._intervals, field_name) @property def chromosome(self) -> str: return self._intervals.chromosome def extended_to_size(self, size: int) -> GenomicIntervals: """Extend intervals along strand to reach the given size Parameters ---------- size : int Returns ------- 'GenomicIntervals' """ chrom_sizes = self._genome_context.global_offset.get_size(self._intervals.chromosome) return self.from_intervals(extend_to_size(self._intervals, size, chrom_sizes), self._genome_context) def merged(self, distance: int = 0) -> GenomicIntervals: """Merge intervals that overlap or lie within distance of eachother Parameters ---------- distance : int Returns ------- 'GenomicIntervals' """ if distance > 0: stream = self.as_stream() return stream.merged(distance).compute() assert distance == 0, 'Distance might cross chromosome boundries so is not supported with current implementation' go = self._genome_context.global_offset global_intervals = go.from_local_interval(self._intervals) global_merged = merge_intervals(global_intervals, distance) return self.from_intervals( self._global_offset.to_local_interval(global_merged), self._genome_context) def get_pileup(self) -> GenomicArray: """Return a genmic array of counting the number of intervals covering each bp Returns ------- GenomicArray Pileup track """ go = self._genome_context.global_offset.from_local_interval(self._intervals) return GenomicArray.from_global_data( get_pileup(go, self._genome_context.size), self._genome_context) def get_mask(self) -> GenomicArray: """Return a boolean mask of areas covered by any interval Returns ------- GenomicArray Genomic mask """ I = RawInterval starts, stops = self._genome_context.global_offset.start_ends_from_intervals(self._intervals) global_mask = get_boolean_mask(I(starts, stops), self._genome_context.size) return GenomicArray.from_global_data(global_mask, self._genome_context) def clip(self) -> 'GenomicIntervalsFull': """Clip the intervals so that they are contained in the genome Returns ------- 'GenomicIntervals' Clipped intervals """ chrom_sizes = self._genome_context.global_offset.get_size(self._intervals.chromosome) return replace(self, start=np.maximum(0, self.start), stop=np.minimum(chrom_sizes, self.stop)) def __replace__(self, **kwargs): return self.__class__(replace(self._intervals, **kwargs), self._genome_context, self._is_stranded) def compute(self): return self def as_stream(self): interval_class = StrandedInterval if self._is_stranded else Interval filled = self.genome_context.iter_chromosomes(self._intervals, interval_class) return GenomicIntervalsStreamed( StreamNode(filled), self._genome_context, self._is_stranded) def get_sorted_stream(self): sorted_intervals = self.sorted() return self.from_interval_stream(iter([sorted_intervals])) def is_stranded(self): return self._is_stranded class GenomicIntervalsStreamed(GenomicIntervals, Node): ''' Class for representing intervals that are grouped by chromosome, and where only intervals for one chromosome at the time is kept in memory ''' is_stream = True def _get_chrom_size(self, intervals: Interval): return self._genome_context.chrom_sizes[intervals.chromosome] def __str__(self): return 'GIS:' + str(self._intervals_node) def __repr__(self): return 'GIS:' + str(self._intervals_node) def __init__(self, intervals_node: Node, genome_context: GenomeContextBase, is_stranded=False): self._genome_context = genome_context self._start = ComputationNode(getattr, [intervals_node, 'start']) self._stop = ComputationNode(getattr, [intervals_node, 'stop']) if is_stranded: self._strand = ComputationNode(getattr, [intervals_node, 'strand']) self._chromosome = ComputationNode(getattr, [intervals_node, 'chromosome']) self._chrom_size_node = StreamNode(iter(self._genome_context.chrom_sizes.values())) self._intervals_node = intervals_node self._is_stranded = is_stranded def is_stranded(self): return self._is_stranded def sorted(self): return NotImplemented @property def start(self): return self._start @property def stop(self): return self._stop @property def chromosome(self): return self._chromosome def get_data_field(self, field_name: str): return ComputationNode(getattr, [self._intervals_node, 'chromosome']) @property def strand(self): if not self.is_stranded(): raise ValueError('Strand not supported on unstranded intervals') return self._strand def __getitem__(self, item): return self.__class__(ComputationNode(lambda x, i: x[i], [self._intervals_node, item]), self._genome_context) def extended_to_size(self, size: int) -> GenomicIntervals: """Extend intervals along strand to reach the given size Parameters ---------- size : int Returns ------- 'GenomicIntervals' """ return self.__class__( ComputationNode(extend_to_size, [self._intervals_node, size, self._chrom_size_node]), self._genome_context) def merged(self, distance: int = 0) -> GenomicIntervals: """Merge intervals that overlap or lie within distance of eachother Parameters ---------- distance : int Returns ------- 'GenomicIntervals' 4 """ return self.__class__(ComputationNode(merge_intervals, [self._intervals_node, distance]), self._genome_context) def get_pileup(self) -> GenomicArray: """Create a GenomicTrack of how many intervals covers each position in the genome Parameters ---------- intervals : Interval Returns ------- GenomicArray """ return GenomicArrayNode(ComputationNode(get_pileup, [self._intervals_node, self._chrom_size_node]), self._genome_context) def get_mask(self) -> GenomicArray: return GenomicArrayNode(ComputationNode(get_boolean_mask, [self._intervals_node, self._chrom_size_node]), self._genome_context) def clip(self) -> 'GenomicIntervals': return self.__class__(ComputationNode(clip, [self._intervals_node, self._chrom_size_node]), self._genome_context) def __replace__(self, **kwargs): return self.__class__(ComputationNode(replace, [self._intervals_node], kwargs), self._genome_context) # return self.__class__(dataclasses.replace(self._intervals, **kwargs), self._genome_context) def compute(self): chromosome, start, stop = compute((self.chromosome, self.start, self.stop)) return GenomicIntervalsFull(Interval(chromosome, start, stop), self._genome_context) def _get_buffer(self, i): return GenomicIntervalsFull(Interval(self.chromosome._get_buffer(i), self.start._get_buffer(i), self.stop._get_buffer(i)), self._genome_context) def as_stream(self): return self def get_location(self, where: str = 'start') -> GenomicLocation: """Get the genomic location of eitert 'start', 'stop' or 'center' of the intervals Parameters ---------- where : str 'start', 'stop' or 'center' Returns ------- GenomicLocation """ assert where == 'start' and not self.is_stranded() return GenomicLocationStreamed(self._intervals_node, self._genome_context, False, {'chromosome': 'chromosome', 'position': 'start', 'strand': 'strand'})