codon/stdlib/re.codon

453 lines
13 KiB
Python

# Copyright (C) 2022-2024 Exaloop Inc. <https://exaloop.io>
# Adapted in part from Google's Python re2 wrapper
# https://github.com/google/re2/blob/abseil/python/re2.py
# Compilation flags. Most flags have a one-letter and a long spelling that
# share the same bit value.
A = 0x01
ASCII = 0x01
DEBUG = 0x02
I = 0x04
IGNORECASE = 0x04
L = 0x08
LOCALE = 0x08
M = 0x10
MULTILINE = 0x10
S = 0x20
DOTALL = 0x20
X = 0x40
VERBOSE = 0x40

# Anchoring modes handed to seq_re_match(): Pattern.search() passes NONE,
# Pattern.match() passes START and Pattern.fullmatch() passes BOTH.
_ANCHOR_NONE = 0
_ANCHOR_START = 1
_ANCHOR_BOTH = 2
@tuple
class Span:
    # A (start, end) pair of offsets. The pair (-1, -1) is the falsy
    # "no span" value; every other pair is truthy.
    start: int
    end: int

    def __bool__(self):
        # False only when both offsets are -1.
        return self.start != -1 or self.end != -1
# External (@C) routine; the `pass` body is only a placeholder.
# Callers in this file index the returned pointer by group number and test
# entry 0 to decide whether anything matched.
@C
@pure
def seq_re_match(re: cobj, anchor: int, string: str, pos: int, endpos: int) -> Ptr[Span]:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Takes the same arguments as seq_re_match() but returns a single Span by
# value. It is not called anywhere in the visible part of this file.
@C
@pure
def seq_re_match_one(re: cobj, anchor: int, string: str, pos: int, endpos: int) -> Span:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Pattern.groups returns this value as the pattern's group count.
@C
@pure
def seq_re_pattern_groups(re: cobj) -> int:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Maps a group name to a group number. Match._get_group_str() range-checks
# the result before using it.
@C
@pure
def seq_re_group_name_to_index(re: cobj, name: str) -> int:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Maps a group number to a name. Pattern.groupindex skips groups for which
# the returned string is empty.
@C
@pure
def seq_re_group_index_to_name(re: cobj, index: int) -> str:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# compile() treats a non-empty result as the message of a compilation error.
@C
@pure
def seq_re_pattern_error(re: cobj) -> str:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Backs the module-level escape() function.
@C
@pure
def seq_re_escape(pattern: str) -> str:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Backs the module-level purge() function. Unlike the other bindings it is
# not marked @pure.
@C
def seq_re_purge() -> None:
    pass
# External (@C) routine; the `pass` body is only a placeholder.
# Returns the opaque handle that compile() stores in Pattern._re and that
# the other seq_re_* routines take as their first argument.
@C
@pure
def seq_re_compile(pattern: str, flags: int) -> cobj:
    pass
class error(Static[Exception]):
    # Exception raised by this module; compile() raises it when the pattern
    # is rejected, and template unescaping raises it for bad escapes.
    pattern: str  # the offending pattern, or "" when none was supplied

    def __init__(self, message: str = "", pattern: str = ""):
        # "re.error" is the type name handed to the base exception.
        super().__init__("re.error", message)
        self.pattern = pattern

    @property
    def msg(self):
        # Alias for the base exception's `message`.
        return self.message
@tuple
class Pattern:
    # A compiled regular expression, as built by compile().
    pattern: str  # source text the pattern was compiled from
    flags: int    # flag bits that were passed to compile()
    _re: cobj     # opaque handle returned by seq_re_compile()
def compile(pattern: str, flags: int = 0):
    # Compile `pattern` with `flags` into a Pattern.
    # Raises `error` (carrying the pattern) when seq_re_pattern_error()
    # reports a non-empty message for the freshly compiled handle.
    handle = seq_re_compile(pattern, flags)
    problem = seq_re_pattern_error(handle)
    if problem:
        raise error(problem, pattern)
    return Pattern(pattern, flags, handle)
def search(pattern: str, string: str, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.search().
    compiled = compile(pattern, flags)
    return compiled.search(string)
def match(pattern: str, string: str, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.match().
    compiled = compile(pattern, flags)
    return compiled.match(string)
def fullmatch(pattern: str, string: str, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.fullmatch().
    compiled = compile(pattern, flags)
    return compiled.fullmatch(string)
def finditer(pattern: str, string: str, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.finditer().
    compiled = compile(pattern, flags)
    return compiled.finditer(string)
def findall(pattern: str, string: str, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.findall().
    compiled = compile(pattern, flags)
    return compiled.findall(string)
def split(pattern: str, string: str, maxsplit: int = 0, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.split().
    compiled = compile(pattern, flags)
    return compiled.split(string, maxsplit)
def sub(pattern: str, repl, string: str, count: int = 0, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.sub().
    compiled = compile(pattern, flags)
    return compiled.sub(repl, string, count)
def subn(pattern: str, repl, string: str, count: int = 0, flags: int = 0):
    # Compile `pattern` with `flags` and delegate to Pattern.subn().
    compiled = compile(pattern, flags)
    return compiled.subn(repl, string, count)
def escape(pattern: str):
    # Thin wrapper: returns whatever seq_re_escape() produces for `pattern`.
    escaped = seq_re_escape(pattern)
    return escaped
def purge():
    # Thin wrapper around the external seq_re_purge() routine; returns
    # nothing.
    seq_re_purge()
@tuple
class Match:
    # A successful match of the pattern `re` against `string`.
    _spans: Ptr[Span]  # spans indexed by group number; entry 0 backs group()
    pos: int           # region start recorded by the Pattern method that built this match
    endpos: int        # region end recorded by the Pattern method that built this match
    re: Pattern        # pattern that produced the match
    string: str        # subject string that the spans index into

    def _get_group_int(self, g: int, n: int):
        # Span of group number `g`, where `n` is the highest valid number.
        # Raises IndexError for anything outside 0..n.
        if not (0 <= g <= n):
            raise IndexError("no such group")
        return self._spans[g]

    def _get_group_str(self, g: str, n: int):
        # Span of the group named `g`; the looked-up number is range-checked
        # by _get_group_int(), which raises IndexError when it is invalid.
        return self._get_group_int(seq_re_group_name_to_index(self.re._re, g), n)

    def _get_group(self, g, n: int):
        # Dispatch on the type of `g`: int, str, or anything providing
        # __index__().
        if isinstance(g, int):
            return self._get_group_int(g, n)
        elif isinstance(g, str):
            return self._get_group_str(g, n)
        else:
            return self._get_group(g.__index__(), n)

    def _span_match(self, span: Span):
        # Text covered by `span`, or None for the falsy "no span" value.
        if not span:
            return None
        return self.string._slice(span.start, span.end)

    def _get_match(self, g, n: int):
        # Text of group `g` (None when its span is falsy).
        span = self._get_group(g, n)
        return self._span_match(span)

    def _group_multi(self, n: int, *args):
        # Build a tuple with one entry per requested group by recursing on
        # the statically sized argument pack.
        if staticlen(args) == 1:
            return (self._get_match(args[0], n),)
        else:
            return (self._get_match(args[0], n), *self._group_multi(n, *args[1:]))

    def group(self, *args):
        # No arguments: text of group 0 as a plain str.
        # One argument: text of that group, or None.
        # Several arguments: a tuple with one entry per argument.
        if staticlen(args) == 0:
            return self._get_match(0, 1).__val__()
        elif staticlen(args) == 1:
            return self._get_match(args[0], self.re.groups)
        else:
            return self._group_multi(self.re.groups, *args)

    def __getitem__(self, g):
        # m[g] gives the same result as m.group(g).
        return self._get_match(g, self.re.groups)

    def start(self, group = 0):
        # Start offset stored in the span of `group`.
        return self._get_group(group, self.re.groups).start

    def end(self, group = 0):
        # End offset stored in the span of `group`.
        return self._get_group(group, self.re.groups).end

    def span(self, group = 0):
        # (start, end) of `group` as a plain tuple.
        start, end = self._get_group(group, self.re.groups)
        return start, end

    def _split(template: str):
        # Cut a replacement template into alternating pieces: even indices
        # hold literal text (escapes still unprocessed) and odd indices hold
        # exactly one group reference (`\N`, `\NN` or `\g<name>`).
        backslash = '\\'
        pieces = ['']
        index = template.find(backslash)
        OCTAL = compile(r'\\[0-7][0-7][0-7]')
        GROUP = compile(r'\\[1-9][0-9]?|\\g<\w+>')
        while index != -1:
            # Everything before the backslash is literal text.
            piece, template = template[:index], template[index:]
            pieces[-1] += piece
            octal_match = OCTAL.match(template)
            group_match = GROUP.match(template)
            if (not octal_match) and group_match:
                # A group reference becomes its own piece, followed by a
                # fresh empty literal piece.
                index = group_match.end()
                piece, template = template[:index], template[index:]
                pieces.extend((piece, ''))
            else:
                # Three octal digits or some other escape: keep the backslash
                # and the next character in the literal for _unescape().
                index = 2
                piece, template = template[:index], template[index:]
                pieces[-1] += piece
            index = template.find(backslash)
        pieces[-1] += template
        return pieces

    def _unescape(s: str):
        # Process backslash escapes in a literal template piece.
        # Handles \a \b \f \n \r \t \v \" \' \\ and octal escapes of one to
        # three digits. A backslash followed by any other letter raises
        # `error`, as does an octal value above 255. Any other backslash
        # sequence, and a lone trailing backslash, is copied through
        # unchanged.
        r = []
        n = len(s)
        i = 0
        while i < n:
            if s[i] == '\\' and i + 1 < n:
                c = s[i + 1]
                if c == 'a':
                    r.append('\a')
                    i += 1
                elif c == 'b':
                    r.append('\b')
                    i += 1
                elif c == 'f':
                    r.append('\f')
                    i += 1
                elif c == 'n':
                    r.append('\n')
                    i += 1
                elif c == 'r':
                    r.append('\r')
                    i += 1
                elif c == 't':
                    r.append('\t')
                    i += 1
                elif c == 'v':
                    r.append('\v')
                    i += 1
                elif c == '"':
                    r.append('\"')
                    i += 1
                elif c == '\'':
                    r.append('\'')
                    i += 1
                elif c == '\\':
                    r.append('\\')
                    i += 1
                elif '0' <= c <= '7':
                    # Octal escape: the digit at i + 1 plus at most two more,
                    # so k never moves past i + 4 (three digits in total).
                    k = i + 2
                    while k < n and k - i < 4 and '0' <= s[k] <= '7':
                        k += 1
                    digits = s[i+1:k]
                    code = int(digits, 8)
                    # Three octal digits can encode up to 511, which does not
                    # fit in the single byte emitted below.
                    if code > 255:
                        raise error(f"octal escape value \\{digits} outside of range 0-0o377")
                    p = Ptr[byte](1)
                    p[0] = byte(code)
                    r.append(str(p, 1))
                    # The shared `i += 1` below then lands exactly on k.
                    i = k - 1
                elif c.isalpha():
                    raise error(f"bad escape \\{c} at position {i}")
                else:
                    # Unknown non-letter escape: keep the backslash; the
                    # following character is copied on the next iteration.
                    r.append(s[i])
            else:
                r.append(s[i])
            i += 1
        return str.cat(r)

    def expand(self, template: str):
        # Return `template` with group references replaced by the matched
        # text (empty string for groups whose text is None) and with the
        # escapes in the literal parts processed by _unescape().
        def get_or_empty(s: Optional[str]):
            return s if s is not None else ''
        pieces = list(Match._split(template))
        INT = compile(r'[+-]?\d+')
        for index, piece in enumerate(pieces):
            if not (index % 2):
                # Even slots are literal text.
                pieces[index] = Match._unescape(piece)
            else:
                if len(piece) <= 3:
                    # `\N` or `\NN`: numeric reference after the backslash.
                    pieces[index] = get_or_empty(self[int(piece[1:])])
                else:
                    # `\g<...>`: strip the `\g<` prefix and the `>` suffix.
                    group = piece[3:-1]
                    if INT.fullmatch(group):
                        pieces[index] = get_or_empty(self[int(group)])
                    else:
                        pieces[index] = get_or_empty(self[group])
        return str.cat(pieces)

    @property
    def lastindex(self):
        # Number of the group (1..re.groups) whose span has the largest end
        # offset; the lowest-numbered group wins ties. None when no group
        # has an end offset above -1.
        max_end = -1
        max_group = None
        for group in range(1, self.re.groups + 1):
            end = self._spans[group].end
            if max_end < end:
                max_end = end
                max_group = group
        return max_group

    @property
    def lastgroup(self):
        # Name of the group reported by `lastindex`, or None when there is
        # no such group or when that group has no name (the name lookup
        # yields an empty string for unnamed groups).
        max_group = self.lastindex
        if max_group is None:
            return None
        name = seq_re_group_index_to_name(self.re._re, max_group)
        if not name:
            return None
        return name

    def groups(self, default: Optional[str] = None):
        # List with the text of every group from 1 to re.groups, using
        # `default` for groups whose text is None.
        def get_or_default(item, default):
            return item if item is not None else default
        n = self.re.groups
        return [get_or_default(self._span_match(self._spans[i]), default)
                for i in range(1, n + 1)]

    def groupdict(self, default: Optional[str] = None):
        # Dict mapping every named group to its text, using `default` for
        # groups whose text is None.
        d = {}
        for group, index in self.re.groupindex.items():
            item = self[index]
            d[group] = item if item is not None else default
        return d

    def __copy__(self):
        # Copying returns the same value.
        return self

    def __deepcopy__(self):
        # Deep-copying returns the same value.
        return self

    def __bool__(self):
        # A Match object is always truthy.
        return True
@extend
class Pattern:
    # Matching API of Pattern, added after Match is defined because these
    # methods construct Match objects.

    @property
    def groups(self):
        # Group count reported by seq_re_pattern_groups().
        return seq_re_pattern_groups(self._re)

    @property
    def groupindex(self):
        # Dict mapping group names to group numbers; groups whose name
        # lookup yields an empty string are left out.
        d = {}
        for i in range(1, self.groups + 1):
            name = seq_re_group_index_to_name(self._re, i)
            if name:
                d[name] = i
        return d

    def _match_one(self, anchor: int, string: str, pos: Optional[int], endpos: Optional[int]):
        # Run a single match with the given anchoring mode.
        # `pos` and `endpos` default to the whole string and are clamped to
        # [0, len(string)]. Returns a Match, or None when the region is
        # inverted or nothing matched.
        posx = 0 if pos is None else max(0, min(pos.__val__(), len(string)))
        endposx = len(string) if endpos is None else max(0, min(endpos.__val__(), len(string)))
        if posx > endposx:
            return None
        spans = seq_re_match(self._re, anchor, string, posx, endposx)
        if not spans[0]:
            return None
        return Match(spans, posx, endposx, self, string)

    def _match(self, anchor: int, string: str, pos: Optional[int], endpos: Optional[int]):
        # Generator yielding successive matches within the clamped region.
        # Every yielded Match records the clamped start of the whole scan as
        # its `pos`, not the moving cursor used for the individual searches.
        start = 0 if pos is None else max(0, min(pos.__val__(), len(string)))
        endposx = len(string) if endpos is None else max(0, min(endpos.__val__(), len(string)))
        if start > endposx:
            return
        posx = start
        while True:
            spans = seq_re_match(self._re, anchor, string, posx, endposx)
            if not spans[0]:
                break
            yield Match(spans, start, endposx, self, string)
            if posx == endposx:
                break
            elif posx == spans[0].end:
                # We matched the empty string at pos and would be stuck, so in order
                # to make forward progress, increment the bytes offset.
                posx += 1
            else:
                posx = spans[0].end

    def search(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        # Single match with no anchoring; Match or None.
        return self._match_one(_ANCHOR_NONE, string, pos, endpos)

    def match(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        # Single match anchored at the start of the region; Match or None.
        return self._match_one(_ANCHOR_START, string, pos, endpos)

    def fullmatch(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        # Single match anchored at both ends of the region; Match or None.
        return self._match_one(_ANCHOR_BOTH, string, pos, endpos)

    def finditer(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        # Generator over successive unanchored matches in the region.
        return self._match(_ANCHOR_NONE, string, pos, endpos)

    def findall(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        # List with the group() text of every match from finditer().
        return [m.group() for m in self.finditer(string, pos, endpos)]

    def _split(self, cb, string: str, maxsplit: int = 0, T: type = str):
        # Shared engine for split() and subn().
        # For each match, appends the text between the previous match and
        # this one, followed by the items returned by cb(match). The text
        # after the last processed match is appended at the end.
        # Returns (pieces, number_of_matches_processed). A negative
        # `maxsplit` processes nothing; 0 means no limit.
        if maxsplit < 0:
            return [T(string)], 0
        pieces: List[T] = []
        end = 0
        numsplit = 0
        for match in self.finditer(string):
            if (maxsplit > 0 and numsplit >= maxsplit):
                break
            pieces.append(string[end:match.start()])
            pieces.extend(cb(match))
            end = match.end()
            numsplit += 1
        pieces.append(string[end:])
        return pieces, numsplit

    def split(self, string: str, maxsplit: int = 0):
        # Split `string` around matches. The text of every group of each
        # match is inserted between the parts, so entries are Optional[str]
        # (None for a group without text).
        cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
        pieces, _ = self._split(cb, string, maxsplit, Optional[str])
        return pieces

    def _repl(match, repl):
        # Replacement text for one match: a str is treated as a template
        # and expanded, anything else is called with the match.
        if isinstance(repl, str):
            return match.expand(repl)
        else:
            return repl(match)

    def subn(self, repl, string: str, count: int = 0):
        # Replace matches with `repl` (at most `count` when positive).
        # Returns (new_string, number_of_replacements).
        cb = lambda match: [Pattern._repl(match, repl)]
        pieces, numsplit = self._split(cb, string, count, str)
        joined_pieces = str.cat(pieces)
        return joined_pieces, numsplit

    def sub(self, repl, string: str, count: int = 0):
        # Same as subn() but returns only the new string.
        joined_pieces, _ = self.subn(repl, string, count)
        return joined_pieces

    def __bool__(self):
        # A Pattern object is always truthy.
        return True