#!/usr/bin/env python3
# Dependencies:
# pip install python-language-server
import sys
import logging
import threading
from pyls import _utils, uris
from pyls_jsonrpc.dispatchers import MethodDispatcher
from pyls_jsonrpc.endpoint import Endpoint
from pyls_jsonrpc.streams import JsonRpcStreamReader, JsonRpcStreamWriter
from pyls.workspace import Workspace
"""
Toy language server that implements textDocument/definition
For example, given this file
```smt2
(declare-const x Int)
(assert (= x 123))
```
if the cursor is on the "x" in line 2, textDocument/definition will return
the position of the x in line 1.
"""
# SMTLIBLanguageServer is adapted from pyls (python-language-server).
# Module-level logger.
log: logging.Logger = logging.getLogger(__name__)
# Seconds between checks that the client's parent process is still alive.
PARENT_PROCESS_WATCH_INTERVAL: int = 10 # 10 s
# Passed to the JSON-RPC Endpoint as ``max_workers`` (presumably the size of
# its worker pool -- confirm against pyls_jsonrpc).
MAX_WORKERS: int = 64
class SMTLIBLanguageServer(MethodDispatcher):
    """Minimal SMT-LIB server for the Microsoft VSCode Language Server Protocol.

    Only ``textDocument/definition`` does real work; the other document
    notifications are accepted and ignored.

    https://github.com/Microsoft/language-server-protocol/blob/master/versions/protocol-1-x.md
    """

    def __init__(self, rx, tx, check_parent_process=False):
        """Set up the JSON-RPC plumbing.

        :param rx: stream the client's messages are read from.
        :param tx: stream responses are written to.
        :param check_parent_process: when true and ``initialize`` carries a
            ``processId``, poll that process and exit once it is gone.
        """
        self.workspace = None
        self.root_uri = None
        self.watching_thread = None
        self.workspaces = {}
        self.uri_workspace_mapper = {}

        self._jsonrpc_stream_reader = JsonRpcStreamReader(rx)
        self._jsonrpc_stream_writer = JsonRpcStreamWriter(tx)
        self._check_parent_process = check_parent_process
        self._endpoint = Endpoint(
            self, self._jsonrpc_stream_writer.write, max_workers=MAX_WORKERS)
        self._dispatchers = []
        self._shutdown = False

    def start(self):
        """Entry point for the server: consume messages from the input stream."""
        self._jsonrpc_stream_reader.listen(self._endpoint.consume)

    def __getitem__(self, item):
        """Override getitem to fall back through multiple dispatchers."""
        if self._shutdown and item != 'exit':
            # exit is the only allowed method during shutdown
            log.debug("Ignoring non-exit method during shutdown: %s", item)
            raise KeyError
        try:
            return super(SMTLIBLanguageServer, self).__getitem__(item)
        except KeyError:
            # Fall back through extra dispatchers
            for dispatcher in self._dispatchers:
                try:
                    return dispatcher[item]
                except KeyError:
                    continue

        raise KeyError()

    def m_shutdown(self, **_kwargs):
        """Handle ``shutdown``: afterwards only ``exit`` is dispatched."""
        self._shutdown = True
        return None

    def m_exit(self, **_kwargs):
        """Handle ``exit``: stop the endpoint and close both JSON-RPC streams."""
        self._endpoint.shutdown()
        self._jsonrpc_stream_reader.close()
        self._jsonrpc_stream_writer.close()

    def _match_uri_to_workspace(self, uri):
        """Return the workspace owning *uri*, falling back to the root workspace."""
        workspace_uri = _utils.match_uri_to_workspace(uri, self.workspaces)
        return self.workspaces.get(workspace_uri, self.workspace)

    def capabilities(self):
        """Return the server capabilities: only go-to-definition is offered."""
        server_capabilities = {
            "definitionProvider": True,
        }
        log.info('Server capabilities: %s', server_capabilities)
        return server_capabilities

    def m_initialize(self, processId=None, rootUri=None, rootPath=None, initializationOptions=None, **_kwargs):
        """Handle ``initialize``: create the root workspace and optionally
        start watching the client's parent process."""
        log.debug('Language server initialized with %s %s %s %s',
                  processId, rootUri, rootPath, initializationOptions)
        if rootUri is None:
            rootUri = uris.from_fs_path(
                rootPath) if rootPath is not None else ''

        self.workspaces.pop(self.root_uri, None)
        self.root_uri = rootUri
        self.workspace = Workspace(rootUri, self._endpoint, None)
        self.workspaces[rootUri] = self.workspace

        if self._check_parent_process and processId is not None and self.watching_thread is None:
            def watch_parent_process(pid):
                # Exit when the given pid is no longer alive.
                if not _utils.is_process_alive(pid):
                    log.info("parent process %s is not alive, exiting!", pid)
                    self.m_exit()
                else:
                    # Daemon timer: a pending re-check must not keep the
                    # interpreter alive after the server has stopped.
                    timer = threading.Timer(PARENT_PROCESS_WATCH_INTERVAL,
                                            watch_parent_process, args=[pid])
                    timer.daemon = True
                    timer.start()

            self.watching_thread = threading.Thread(
                target=watch_parent_process, args=(processId,))
            self.watching_thread.daemon = True
            self.watching_thread.start()

        return {'capabilities': self.capabilities()}

    def m_initialized(self, **_kwargs):
        pass

    def m_text_document__definition(self, textDocument=None, position=None, **_kwargs):
        """Handle ``textDocument/definition``.

        Returns a Location dict, or None (no definition) when the request
        lacks a document URI or position, or no workspace is available yet.
        """
        doc_uri = (textDocument or {}).get("uri")
        if not doc_uri or position is None:
            return None
        workspace = self._match_uri_to_workspace(doc_uri)
        if workspace is None:
            # e.g. a request arriving before ``initialize``.
            return None
        doc = workspace.get_document(doc_uri)
        return smt_definition(doc, position)

    def m_text_document__did_close(self, textDocument=None, **_kwargs):
        pass

    def m_text_document__did_open(self, textDocument=None, **_kwargs):
        pass

    def m_text_document__did_change(self, contentChanges=None, textDocument=None, **_kwargs):
        pass

    def m_text_document__did_save(self, textDocument=None, **_kwargs):
        pass

    def m_text_document__completion(self, textDocument=None, **_kwargs):
        pass
def flatten(list_of_lists):
    """Concatenate the items of each inner iterable into one flat list."""
    flat = []
    for sub in list_of_lists:
        flat.extend(sub)
    return flat
def merge(list_of_dicts):
    """Combine dicts left to right into a new dict; later keys win."""
    combined = {}
    for mapping in list_of_dicts:
        combined.update(mapping)
    return combined
def smt_definition(document, position):
    """Resolve the LSP Location of the definition of the symbol at *position*.

    :param document: document object; this reads its ``source``, ``lines``
        and ``uri`` attributes. May be None, in which case None is returned.
    :param position: LSP position dict with ``line`` and ``character`` keys.
    :returns: a Location dict whose (empty) range points at the defining
        symbol, or None when nothing is found.
    """
    if document is None or position is None:
        return None
    pos = definition(document.source, position["line"], position["character"])
    if pos is None:
        return None
    line, col, token = pos
    # The parser reports the position *after* an atom, so step back by the
    # atom's length to reach its start.
    # NOTE(review): single-character atoms use a different offset than longer
    # ones; this appears tuned to the parser's column numbering (0-based on
    # the first line, effectively 1-based afterwards) -- confirm before changing.
    offset = 1 if len(token) == 1 else (len(token) + 1)
    if col == 0:
        # The atom was ended by a newline, which the parser already counted as
        # the start of the next line: go back to the end of the previous line.
        # Assumes document.lines keeps line terminators -- TODO confirm.
        line -= 1
        col = len(document.lines[line]) - offset
    else:
        col = col - offset
    return {
        'uri': document.uri,
        'range': {
            'start': {'line': line, 'character': col},
            'end': {'line': line, 'character': col},
        }
    }
def definition(source, cursor_line, cursor_character):
    """Find where the atom under the cursor is declared/defined.

    :param source: full SMT-LIB text.
    :param cursor_line: cursor line, in the coordinates used by ``parser``.
    :param cursor_character: cursor column, in the same coordinates.
    :returns: ``(line, col, symbol)`` -- the parser position recorded for the
        defining symbol and the symbol text itself -- or None when the cursor
        is not on an atom or no definition exists.
    """
    nodes = list(parser().parse_smtlib(source))
    node_at_cursor = find_leaf_node_at(cursor_line, cursor_character, nodes)
    if node_at_cursor is None:
        # Cursor on whitespace, a comment, or outside every form.
        return None
    line, col, node = find_definition_for(node_at_cursor, nodes)
    if node is None:
        return None
    # Return the defining symbol rather than the text under the cursor: the
    # two differ for tester references ("is-cons" -> "cons"), and callers
    # use the symbol's length to locate its start.
    return line, col, node
def find_leaf_node_at(line, col, nodes):
    """Return the atom (string) under the cursor ``(line, col)``, or None.

    :param nodes: ``(line_end, col_end, value)`` triples as produced by
        ``parser``, where ``value`` is a string atom or a list of child triples.
    :returns: the atom string whose recorded end lies after the cursor, or
        None when the cursor is not on any atom (e.g. on whitespace just
        before a closing paren, or inside ``()``).
    """
    prev_line_end = -1
    prev_col_end = -1
    needle = (line, col)
    for line_end, col_end, node in nodes:
        # NOTE(review): the lower bound is relaxed by one line. Because nodes
        # are scanned in order, the first node whose end is after the cursor is
        # the one that matches; confirm whether the "- 1" is intended.
        prev_range = (prev_line_end - 1, prev_col_end)
        cur_range = (line_end, col_end)
        if prev_range < needle < cur_range:
            if isinstance(node, str):
                return node
            # Descend into the s-expression. Finding nothing there is a normal
            # outcome (no atom under the cursor), not an error.
            return find_leaf_node_at(line, col, node)
        prev_line_end = line_end
        prev_col_end = col_end
    return None
def stripprefix(x, prefix):
    """Return *x* without a leading *prefix*; *x* unchanged if it lacks it."""
    return x[len(prefix):] if x.startswith(prefix) else x
def find_definition_for(needle, nodes):
    """Locate the declaration/definition of the symbol *needle*.

    Supported commands:
    - declare-const, define-const, declare-fun, define-fun, define-fun-rec and
      declare-datatype: match the declared name.
    - declare-datatypes and define-funs-rec: match each declared name.
    - declare-datatypes additionally: any matching atom inside the command,
      which covers constructors/selectors, and ``is-<name>`` testers.

    :param needle: atom text under the cursor; anything but a str yields no result.
    :param nodes: top-level ``(line_end, col_end, value)`` triples from ``parser``.
    :returns: the triple of the defining symbol, or ``(-1, -1, None)``.
    """
    if not isinstance(needle, str):
        return -1, -1, None
    for node in nodes:
        _, _, n = node
        # Only non-empty s-expressions headed by an atom can be commands; skip
        # top-level atoms, "()" and forms such as "((...) ...)".
        if not isinstance(n, list) or not n:
            continue
        _, _, head = n[0]
        if not isinstance(head, str):
            continue
        if not head.startswith(("declare-", "define-")):
            continue
        if len(n) < 2:
            # Incomplete command such as "(declare-const" while typing.
            continue
        _, _, symbol = n[1]
        if head in ("declare-const", "define-const", "declare-fun", "define-fun", "define-fun-rec", "declare-datatype"):
            if symbol == needle:
                return n[1]
            continue
        if head in ("declare-datatypes", "define-funs-rec"):
            if isinstance(symbol, list):
                for entry in symbol:
                    _, _, declaration = entry
                    # Entries may be atoms (e.g. type parameters in older
                    # declare-datatypes syntax); only lists carry a name.
                    if not isinstance(declaration, list) or not declaration:
                        continue
                    _, _, name = declaration[0]
                    if name == needle:
                        return declaration[0]
            if head == "declare-datatypes":
                constructor = dfs(needle, node)
                if constructor is not None:
                    return constructor
                constructor = dfs(stripprefix(needle, "is-"), node)
                if constructor is not None:
                    return constructor
            continue
        # Other declare-/define- forms (e.g. define-sort) are not supported
        # and are skipped.
    return -1, -1, None
def dfs(needle, node):
    """Depth-first, pre-order search for the first atom equal to *needle*.

    :param node: a ``(line, col, value)`` triple whose value is a string atom
        or a list of child triples.
    :returns: the matching triple itself, or None.
    """
    assert isinstance(node, tuple)
    _, _, payload = node
    if isinstance(payload, str):
        return node if payload == needle else None
    # Search children lazily, left to right, stopping at the first hit.
    hits = (dfs(needle, child) for child in payload)
    return next((hit for hit in hits if hit is not None), None)
class parser:
    """Small reader for SMT-LIB text.

    ``parse_smtlib`` returns a generator of ``(line, col, value)`` triples, one
    per top-level item. ``value`` is either a string atom or a list of such
    triples for an s-expression. Atoms include symbols and numerals, plus
    string literals and quoted symbols, which keep their delimiters.

    ``(line, col)`` is the reader position right after the item was read:
    - for a list, its closing paren;
    - for a literal, its closing delimiter;
    - for an atom ended by whitespace, that whitespace character (a newline
      is counted as ``(next_line, 0)``);
    - for an atom ended by ``(``, ``)``, ``;`` or end of input, its last
      character.

    NOTE(review): columns start at -1 but are reset to 0 after a newline, so
    they are 0-based on the first line and effectively 1-based afterwards.
    Downstream offset arithmetic appears tuned to this, so it is unchanged.

    An instance is single-use.
    """

    # SMT-LIB whitespace characters: space, tab, line feed, carriage return.
    _WHITESPACE = (' ', '\t', '\n', '\r')

    def __init__(self):
        self.pos = 0     # index of the next unread character
        self.line = 0    # reader line (see class docstring for conventions)
        self.col = -1    # reader column (see class docstring for conventions)
        self.text = None

    def nextch(self):
        """Consume and return the next character, updating line/column."""
        char = self.text[self.pos]
        self.pos += 1
        self.col += 1
        if char == "\n":
            self.line += 1
            self.col = 0
        return char

    def parse_smtlib(self, text):
        """Parse *text*; return a generator of top-level triples."""
        assert self.text is None
        self.text = text
        return self.parse_smtlib_aux()

    def parse_smtlib_aux(self):
        """Generator behind ``parse_smtlib``.

        Unterminated string literals/quoted symbols stop parsing; forms still
        open at end of input are not yielded.
        """
        exprs = []       # stack of currently open s-expressions
        cur_expr = None  # innermost open s-expression; None at top level
        size = len(self.text)
        while self.pos < size:
            char = self.nextch()
            # Stolen from ddSMT's parser. Not fully SMT-LIB compliant but good enough.

            # String literals / quoted symbols.
            if char in ('"', '|'):
                first_char = char
                literal = [char]
                # Read until the terminating " or |.
                while True:
                    if self.pos >= size:
                        return
                    char = self.nextch()
                    literal.append(char)
                    if char == first_char:
                        # "a "" b" is one string literal with an escaped quote.
                        if char == '"' and self.pos < size and self.text[self.pos] == '"':
                            literal.append(self.nextch())
                            continue
                        break
                # Store atoms as strings: a list here would be mistaken for a
                # nested s-expression by code walking the tree.
                item = (self.line, self.col, ''.join(literal))
                if cur_expr is not None:
                    cur_expr.append(item)
                else:
                    yield item
                continue

            # Comments: skip to the end of the line.
            if char == ';':
                while self.pos < size:
                    if self.nextch() == '\n':
                        break
                continue

            # Open s-expression.
            if char == '(':
                cur_expr = []
                exprs.append(cur_expr)
                continue

            # Close s-expression.
            if char == ')':
                if not exprs:
                    # Unbalanced ")" (common while editing): ignore it.
                    continue
                cur_expr = exprs.pop()
                if exprs:
                    # Nested s-expression: attach it to its parent.
                    exprs[-1].append((self.line, self.col, cur_expr))
                    cur_expr = exprs[-1]
                else:
                    yield self.line, self.col, cur_expr
                    cur_expr = None
                continue

            # Identifier / other atom.
            if char not in self._WHITESPACE:
                token = [char]
                while self.pos < size:
                    char = self.text[self.pos]
                    if char in ('(', ')', ';'):
                        # Delimiter is left for the main loop.
                        break
                    self.nextch()
                    if char in self._WHITESPACE:
                        break
                    token.append(char)
                # An atom ending the input is still emitted.
                item = (self.line, self.col, ''.join(token))
                if cur_expr is not None:
                    cur_expr.append(item)
                else:
                    yield item
def serve():
    """Run the language server over this process's binary stdin/stdout."""
    SMTLIBLanguageServer(sys.stdin.buffer, sys.stdout.buffer).start()
if __name__ == "__main__":
    # Debug helper: `<script> definition LINE COLUMN < file.smt2` prints the
    # raw result of definition() for that cursor position; otherwise serve LSP.
    if len(sys.argv) >= 2 and sys.argv[1] == "definition":
        if len(sys.argv) < 4:
            sys.exit(f"usage: {sys.argv[0]} definition LINE COLUMN < FILE")
        line = int(sys.argv[2])
        col = int(sys.argv[3])
        print(definition(sys.stdin.read(), line, col))
    else:
        serve()