
Comments (1)

sweep-nightly commented on June 15, 2024

🚀 Here's the PR! #3841

💎 Sweep Pro: You have unlimited Sweep issues


Step 1: 🔎 Searching

Here are the code search results. I'm now analyzing these search results to write the PR.

Relevant files. Mentioned files will always appear here.
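
sweepai/core/vector_db.py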

import json
import multiprocessing
import os
from typing import Generator
import backoff
import numpy as np
import openai
import requests
from loguru import logger
from redis import Redis
from tqdm import tqdm
import voyageai
import boto3
from botocore.exceptions import ClientError
from voyageai import error as voyageai_error
from sweepai.config.server import BATCH_SIZE, REDIS_URL, VOYAGE_API_AWS_ENDPOINT_NAME, VOYAGE_API_KEY, VOYAGE_API_USE_AWS
from sweepai.utils.hash import hash_sha256
from sweepai.utils.openai_proxy import get_embeddings_client
from sweepai.utils.utils import Tiktoken
# Now uses Voyage AI if available, with asymmetric embedding
# CACHE_VERSION = "v2.0.04" + "-voyage" if VOYAGE_API_KEY else ""
suffix = "-voyage-aws" if VOYAGE_API_USE_AWS else "-voyage" if VOYAGE_API_KEY else ""
CACHE_VERSION = "v2.0.08" + suffix
redis_client: Redis = Redis.from_url(REDIS_URL) # TODO: add lazy loading
tiktoken_client = Tiktoken()
def cosine_similarity(a, B):
    """
    Updated to handle multi-queries.
    """
    dot_product = np.dot(B, a.T)  # B is MxN, a.T is Nxq, resulting in Mxq
    norm_a = np.linalg.norm(a, axis=1)
    norm_B = np.linalg.norm(B, axis=1)
    dot_product /= norm_a
    dot_product = dot_product.T / norm_B
    return dot_product
def chunk(texts: list[str], batch_size: int) -> Generator[list[str], None, None]:
    logger.info(f"Truncating {len(texts)} texts")
    texts = [text[:25000] if len(text) > 25000 else text for text in texts]
    # replace empty strings with a single space so the embedding API accepts them
    texts = [text if text else " " for text in texts]
    logger.info(f"Finished truncating {len(texts)} texts")
    for i in range(0, len(texts), batch_size):
        yield texts[i : i + batch_size]
# @file_cache(ignore_params=["texts"])
def multi_get_query_texts_similarity(queries: list[str], documents: list[str]) -> list[list[float]]:
    if not documents:
        return []
    embeddings = embed_text_array(documents)
    embeddings = np.concatenate(embeddings)
    query_embedding = np.array(openai_call_embedding(queries, input_type="query"))
    similarity = cosine_similarity(query_embedding, embeddings)
    similarity = similarity.tolist()
    return similarity
def normalize_l2(x):
    x = np.array(x)
    if x.ndim == 1:
        norm = np.linalg.norm(x)
        if norm == 0:
            return x
        return x / norm
    else:
        norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
        return np.where(norm == 0, x, x / norm)
def batch_by_token_count_for_voyage(
    texts: list[str],
    max_tokens: int = 120_000,
    max_length: int = 128,
) -> list[list[str]]:
    """
    Splits the texts into batches based on token count.
    The max token count per Voyage batch is 120k and the max batch length is 128.
    """
    client = voyageai.Client()
    batches = []
    batch = []
    token_count = 0
    for text in texts:
        text_token_count = client.count_tokens([text])
        if token_count + text_token_count > max_tokens * 0.95 or len(batch) >= max_length:
            batches.append(batch)
            batch = [text]  # start the new batch with the current text
            token_count = text_token_count  # reset the token count for the new batch
        else:
            batch.append(text)
            token_count += text_token_count
    if batch:
        batches.append(batch)
    del client
    return batches
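# Worked example of the batching math above (illustrative numbers): the token
# guard trips at 114,000 tokens (0.95 * 120_000). With ~500-token texts the
# 128-item length cap fills first (128 * 500 = 64,000 tokens), while with
# ~1,000-token texts the token guard closes the batch at around 114 items.
# Note: if the very first text alone exceeds the threshold, the guard appends
# an empty initial batch, since `batch` starts out empty.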
# lru_cache(maxsize=20)
# @redis_cache()
def embed_text_array(texts: list[str]) -> list[np.ndarray]:
    embeddings = []
    texts = [text if text else " " for text in texts]
    batches = [texts[i : i + BATCH_SIZE] for i in range(0, len(texts), BATCH_SIZE)]
    # NOTE: min(..., 1) caps workers at 1, so the multiprocessing branch below
    # is currently dead code and batches are embedded sequentially.
    workers = min(max(1, multiprocessing.cpu_count() // 4), 1)
    if workers > 1:
        with multiprocessing.Pool(processes=workers) as pool:
            embeddings = list(
                tqdm(
                    pool.imap(openai_with_expo_backoff, batches),
                    total=len(batches),
                    desc="openai embedding",
                )
            )
    else:
        embeddings = [openai_with_expo_backoff(batch) for batch in tqdm(batches, desc="openai embedding")]
    return embeddings
# @redis_cache()
def openai_call_embedding_router(batch: list[str], input_type: str = "document"):  # input_type can be "query" or "document"
    VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
    VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY", None)
    VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY", None)
    VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION", None)
    VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
    if len(batch) == 0:
        return np.array([])
    if VOYAGE_API_USE_AWS:
        sm_runtime = boto3.client(
            "sagemaker-runtime",
            aws_access_key_id=VOYAGE_API_AWS_ACCESS_KEY,
            aws_secret_access_key=VOYAGE_API_AWS_SECRET_KEY,
            region_name=VOYAGE_API_AWS_REGION,
        )
        input_json = json.dumps({
            "input": batch,
            "input_type": input_type,
            "truncation": "true",
        })
        response = sm_runtime.invoke_endpoint(
            EndpointName=VOYAGE_API_AWS_ENDPOINT_NAME,
            ContentType="application/json",
            Accept="application/json",
            Body=input_json,
        )
        body = response["Body"]
        obj = json.load(body)
        data = obj["data"]
        return np.array([vector["embedding"] for vector in data])
    elif VOYAGE_API_KEY:
        client = voyageai.Client(api_key=VOYAGE_API_KEY)
        result = client.embed(batch, model="voyage-code-2", input_type=input_type, truncation=True)
        cut_dim = np.array([data for data in result.embeddings])
        normalized_dim = normalize_l2(cut_dim)
        del client
        return normalized_dim
    else:
        client = get_embeddings_client()
        response = client.embeddings.create(
            input=batch, model="text-embedding-3-small", encoding_format="float"
        )
        cut_dim = np.array([data.embedding for data in response.data])[:, :512]
        normalized_dim = normalize_l2(cut_dim)
        return normalized_dim
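# Note: the OpenAI fallback above keeps only the first 512 embedding dimensions
# of text-embedding-3-small and re-normalizes them; OpenAI's v3 embeddings are
# designed to support this kind of shortening as long as the truncated vectors
# are L2-normalized afterwards.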
def openai_call_embedding(batch: list[str], input_type: str = "document"):
    # Backoff on batch size by splitting the batch in half.
    try:
        return openai_call_embedding_router(batch, input_type)
    except (voyageai_error.InvalidRequestError, ClientError) as e:  # the full error is botocore.errorfactory.ModelError, but it isn't importable
        if len(batch) > 1 and "Please lower the number of tokens in the batch." in str(e):
            logger.error(f"Token count exceeded for batch: {max([tiktoken_client.count(text) for text in batch])}, retrying by splitting the batch in half.")
            mid = len(batch) // 2
            left = openai_call_embedding(batch[:mid], input_type)
            right = openai_call_embedding(batch[mid:], input_type)
            return np.concatenate((left, right))
        else:
            raise e
    except openai.BadRequestError as e:
        # In the future we can better handle this by averaging the embeddings of the split batch
        if "This model's maximum context length" in str(e):
            logger.warning(f"Token count exceeded for batch: {max([tiktoken_client.count(text) for text in batch])}, truncating down to 8192 tokens.")
            batch = [tiktoken_client.truncate_string(text) for text in batch]
            return openai_call_embedding(batch, input_type)
        raise e  # re-raise anything we can't handle by truncation
@backoff.on_exception(
    backoff.expo,
    requests.exceptions.Timeout,
    max_tries=5,
)
def openai_with_expo_backoff(batch: list[str]):
    if not redis_client:
        return openai_call_embedding(batch)
    # check the cache first
    embeddings = [None] * len(batch)
    cache_keys = [hash_sha256(text) + CACHE_VERSION for text in batch]
    try:
        for i, cache_value in enumerate(redis_client.mget(cache_keys)):
            if cache_value:
                embeddings[i] = np.array(json.loads(cache_value))
    except Exception as e:
        logger.exception(e)
    # values not stored in the cache are embedded via openai
    batch = [
        text for i, text in enumerate(batch) if embeddings[i] is None
    ]  # remove all the cached values from the batch
    if len(batch) == 0:
        embeddings = np.array(embeddings)
        return embeddings  # all embeddings are in cache
    try:
        # make sure all token counts are within model params (max: 8192)
        new_embeddings = openai_call_embedding(batch)
    except requests.exceptions.Timeout as e:
        logger.exception(f"Timeout error occurred while embedding: {e}")
        raise e  # re-raise so the backoff decorator can retry
    except Exception as e:
        logger.exception(e)
        if any(tiktoken_client.count(text) > 8192 for text in batch):
            logger.warning(
                f"Token count exceeded for batch: {max([tiktoken_client.count(text) for text in batch])} truncating down to 8192 tokens."
            )
            batch = [tiktoken_client.truncate_string(text) for text in batch]
            new_embeddings = openai_call_embedding(batch)
        else:
            raise e
    # get all indices where embeddings are None
    indices = [i for i, emb in enumerate(embeddings) if emb is None]
    # store the new embeddings in the correct positions
    assert len(indices) == len(new_embeddings)
    for i, index in enumerate(indices):
        embeddings[index] = new_embeddings[i]
    # store in cache
    try:
        redis_client.mset(
            {
                cache_key: json.dumps(embedding.tolist())
                for cache_key, embedding in zip(cache_keys, embeddings)
            }
        )
        embeddings = np.array(embeddings)
    except Exception:
        # failed to store embeddings in the cache; return them without caching
        pass
    return embeddings
if __name__ == "__main__":
    texts = ["sasxtt " * 10000 for i in range(10)] + ["abb " * 1 for i in range(10)]
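
For orientation, a minimal usage sketch of the retrieval path above (the query and document strings are hypothetical, invented purely for illustration):

# Hypothetical example: score two documents against a single query.
similarities = multi_get_query_texts_similarity(
    queries=["How do I connect to Redis?"],
    documents=["redis_client = Redis.from_url(REDIS_URL)", "def unrelated(): pass"],
)
# similarities[0] holds the cosine similarity of each document to the first
# query (one row per query, one column per document); higher is more similar.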

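sweepai/config/server.py
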
import base64
import os
from dotenv import load_dotenv
from loguru import logger
logger.print = logger.info
load_dotenv(dotenv_path=".env", override=True, verbose=True)
os.environ["GITHUB_APP_PEM"] = os.environ.get("GITHUB_APP_PEM") or base64.b64decode(
os.environ.get("GITHUB_APP_PEM_BASE64", "")
).decode("utf-8")
if os.environ["GITHUB_APP_PEM"]:
os.environ["GITHUB_APP_ID"] = (
(os.environ.get("GITHUB_APP_ID") or os.environ.get("APP_ID"))
.replace("\\n", "\n")
.strip('"')
)
TEST_BOT_NAME = "sweep-nightly[bot]"
ENV = os.environ.get("ENV", "dev")
BOT_TOKEN_NAME = "bot-token"
# goes under Modal 'discord' secret name (optional, can leave env var blank)
DISCORD_WEBHOOK_URL = os.environ.get("DISCORD_WEBHOOK_URL")
DISCORD_MEDIUM_PRIORITY_URL = os.environ.get("DISCORD_MEDIUM_PRIORITY_URL")
DISCORD_LOW_PRIORITY_URL = os.environ.get("DISCORD_LOW_PRIORITY_URL")
DISCORD_FEEDBACK_WEBHOOK_URL = os.environ.get("DISCORD_FEEDBACK_WEBHOOK_URL")
SWEEP_HEALTH_URL = os.environ.get("SWEEP_HEALTH_URL")
DISCORD_STATUS_WEBHOOK_URL = os.environ.get("DISCORD_STATUS_WEBHOOK_URL")
# goes under Modal 'github' secret name
GITHUB_APP_ID = os.environ.get("GITHUB_APP_ID", os.environ.get("APP_ID"))
# deprecated: old logic transfer so upstream can use this
if GITHUB_APP_ID is None:
    if ENV == "prod":
        GITHUB_APP_ID = "307814"
    elif ENV == "dev":
        GITHUB_APP_ID = "324098"
    elif ENV == "staging":
        GITHUB_APP_ID = "327588"
GITHUB_BOT_USERNAME = os.environ.get("GITHUB_BOT_USERNAME")
# deprecated: left to support old logic
if not GITHUB_BOT_USERNAME:
    if ENV == "prod":
        GITHUB_BOT_USERNAME = "sweep-ai[bot]"
    elif ENV == "dev":
        GITHUB_BOT_USERNAME = "sweep-nightly[bot]"
    elif ENV == "staging":
        GITHUB_BOT_USERNAME = "sweep-canary[bot]"
elif not GITHUB_BOT_USERNAME.endswith("[bot]"):
    GITHUB_BOT_USERNAME = GITHUB_BOT_USERNAME + "[bot]"
GITHUB_LABEL_NAME = os.environ.get("GITHUB_LABEL_NAME", "sweep")
GITHUB_LABEL_COLOR = os.environ.get("GITHUB_LABEL_COLOR", "9400D3")
GITHUB_LABEL_DESCRIPTION = os.environ.get(
    "GITHUB_LABEL_DESCRIPTION", "Sweep your software chores"
)
GITHUB_APP_PEM = os.environ.get("GITHUB_APP_PEM")
GITHUB_APP_PEM = GITHUB_APP_PEM or os.environ.get("PRIVATE_KEY")
if GITHUB_APP_PEM is not None:
    GITHUB_APP_PEM = GITHUB_APP_PEM.strip(' \n"')  # remove whitespace and quotes
    GITHUB_APP_PEM = GITHUB_APP_PEM.replace("\\n", "\n")
GITHUB_CONFIG_BRANCH = os.environ.get("GITHUB_CONFIG_BRANCH", "sweep/add-sweep-config")
GITHUB_DEFAULT_CONFIG = os.environ.get(
    "GITHUB_DEFAULT_CONFIG",
    """# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config
# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
{additional_rules}
# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'
# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True
# This is the description of your project. It will be used by Sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''
# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False
# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []
""",
)
MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = os.environ.get("IS_SELF_HOSTED", "true").lower() == "true"
REDIS_URL = os.environ.get("REDIS_URL")
if not REDIS_URL:
    REDIS_URL = os.environ.get("redis_url", "redis://0.0.0.0:6379/0")
ORG_ID = os.environ.get("ORG_ID", None)
POSTHOG_API_KEY = os.environ.get(
    "POSTHOG_API_KEY", "phc_CnzwIB0W548wN4wEGeRuxXqidOlEUH2AcyV2sKTku8n"
)
SUPPORT_COUNTRY = os.environ.get("GDRP_LIST", "").split(",")
WHITELISTED_REPOS = os.environ.get("WHITELISTED_REPOS", "").split(",")
BLACKLISTED_USERS = os.environ.get("BLACKLISTED_USERS", "").split(",")
# Default OpenAI
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) # this may be none, and it will use azure
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "anthropic")
assert OPENAI_API_TYPE in ["anthropic", "azure", "openai"], "Invalid OPENAI_API_TYPE"
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
AZURE_API_KEY = os.environ.get("AZURE_API_KEY", None)
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", None)
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", None)
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", None)
OPENAI_EMBEDDINGS_API_TYPE = os.environ.get("OPENAI_EMBEDDINGS_API_TYPE", "openai")
OPENAI_EMBEDDINGS_AZURE_ENDPOINT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_ENDPOINT", None
)
OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_DEPLOYMENT", None
)
OPENAI_EMBEDDINGS_AZURE_API_VERSION = os.environ.get(
    "OPENAI_EMBEDDINGS_AZURE_API_VERSION", None
)
OPENAI_API_ENGINE_GPT35 = os.environ.get("OPENAI_API_ENGINE_GPT35", None)
OPENAI_API_ENGINE_GPT4 = os.environ.get("OPENAI_API_ENGINE_GPT4", None)
MULTI_REGION_CONFIG = os.environ.get("MULTI_REGION_CONFIG", None)
if isinstance(MULTI_REGION_CONFIG, str):
    MULTI_REGION_CONFIG = MULTI_REGION_CONFIG.strip("'").replace("\\n", "\n")
    MULTI_REGION_CONFIG = [item.split(",") for item in MULTI_REGION_CONFIG.split("\n")]
WHITELISTED_USERS = os.environ.get("WHITELISTED_USERS", None)
if WHITELISTED_USERS:
    WHITELISTED_USERS = WHITELISTED_USERS.split(",")
    WHITELISTED_USERS.append(GITHUB_BOT_USERNAME)
DEFAULT_GPT4_MODEL = os.environ.get("DEFAULT_GPT4_MODEL", "gpt-4-0125-preview")
RESEND_API_KEY = os.environ.get("RESEND_API_KEY", None)
LOKI_URL = None
DEBUG = os.environ.get("DEBUG", "false").lower() == "true"
ENV = "prod" if GITHUB_BOT_USERNAME != TEST_BOT_NAME else "dev"
PROGRESS_BASE_URL = os.environ.get(
    "PROGRESS_BASE_URL", "https://progress.sweep.dev"
).rstrip("/")
DISABLED_REPOS = os.environ.get("DISABLED_REPOS", "").split(",")
GHA_AUTOFIX_ENABLED: bool = os.environ.get("GHA_AUTOFIX_ENABLED", False)
MERGE_CONFLICT_ENABLED: bool = os.environ.get("MERGE_CONFLICT_ENABLED", False)
INSTALLATION_ID = os.environ.get("INSTALLATION_ID", None)
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.environ.get("AWS_SECRET_KEY")
AWS_REGION = os.environ.get("AWS_REGION")
ANTHROPIC_AVAILABLE = AWS_ACCESS_KEY and AWS_SECRET_KEY and AWS_REGION
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
COHERE_API_KEY = os.environ.get("COHERE_API_KEY", None)
VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", None)
VOYAGE_API_AWS_ACCESS_KEY = os.environ.get("VOYAGE_API_AWS_ACCESS_KEY_ID")
VOYAGE_API_AWS_SECRET_KEY = os.environ.get("VOYAGE_API_AWS_SECRET_KEY")
VOYAGE_API_AWS_REGION = os.environ.get("VOYAGE_API_AWS_REGION")
VOYAGE_API_AWS_ENDPOINT_NAME = os.environ.get("VOYAGE_API_AWS_ENDPOINT_NAME", "voyage-code-2")
VOYAGE_API_USE_AWS = VOYAGE_API_AWS_ACCESS_KEY and VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION
PAREA_API_KEY = os.environ.get("PAREA_API_KEY", None)
# TODO: we need to make this dynamic + backoff
BATCH_SIZE = int(
    os.environ.get("BATCH_SIZE", 64 if VOYAGE_API_KEY else 256)  # Voyage only allows 128 items per batch and 120,000 tokens per batch
)
DEPLOYMENT_GHA_ENABLED = os.environ.get("DEPLOYMENT_GHA_ENABLED", "true").lower() == "true"
JIRA_USER_NAME = os.environ.get("JIRA_USER_NAME", None)
JIRA_API_TOKEN = os.environ.get("JIRA_API_TOKEN", None)
JIRA_URL = os.environ.get("JIRA_URL", None)
SLACK_API_KEY = os.environ.get("SLACK_API_KEY", None)
LICENSE_KEY = os.environ.get("LICENSE_KEY", None)
ALTERNATE_AWS = os.environ.get("ALTERNATE_AWS", "none").lower() == "true"
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET", None)
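
For orientation (my summary, inferred from the two files above rather than stated in either): the embedding backend used by openai_call_embedding_router is selected by which credentials are present.

# Inferred backend precedence:
# 1. Voyage on AWS SageMaker, when VOYAGE_API_AWS_ACCESS_KEY,
#    VOYAGE_API_AWS_SECRET_KEY and VOYAGE_API_AWS_REGION are all set in the
#    environment (the endpoint name defaults to "voyage-code-2"). Note that
#    sweepai/config/server.py reads VOYAGE_API_AWS_ACCESS_KEY_ID instead,
#    so both spellings appear in this codebase.
# 2. The hosted Voyage API (model "voyage-code-2"), when VOYAGE_API_KEY is set.
# 3. Otherwise, OpenAI "text-embedding-3-small" via get_embeddings_client().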

Step 2: ⌨️ Coding

sweepai/core/vector_db.py

Add the necessary imports for the backoff and Redis timeout functionality.
--- 
+++ 
@@ -9,6 +9,7 @@
 import requests
 from loguru import logger
 from redis import Redis
+from redis.exceptions import TimeoutError
 from tqdm import tqdm
 import voyageai
 import boto3

sweepai/core/vector_db.py

Modify the `openai_with_expo_backoff` function so the Redis cache read is wrapped in a backoff-decorated helper that retries when the query times out.
--- 
+++ 
@@ -1,11 +1,23 @@
     cache_keys = [hash_sha256(text) + CACHE_VERSION for text in batch]
-    try:
-        for i, cache_value in enumerate(redis_client.mget(cache_keys)):
-            if cache_value:
-                embeddings[i] = np.array(json.loads(cache_value))
-    except Exception as e:
-        logger.exception(e)
-    # not stored in cache call openai
+
+    @backoff.on_exception(backoff.expo, TimeoutError, max_tries=5)
+    def get_cached_embeddings():
+        try:
+            # NOTE: redis-py's mget takes no timeout argument; the timeout
+            # must come from socket_timeout configured on the Redis client
+            cache_values = redis_client.mget(cache_keys)
+            for i, cache_value in enumerate(cache_values):
+                if cache_value:
+                    embeddings[i] = np.array(json.loads(cache_value))
+        except TimeoutError:
+            logger.warning("Redis query timed out, retrying...")
+            raise
+        except Exception as e:
+            logger.exception(e)
+
+    get_cached_embeddings()
+
+    # not stored in cache, call openai
     batch = [
         text for i, text in enumerate(batch) if embeddings[i] is None
     ]  # remove all the cached values from the batch
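
One caveat on the patch above: redis-py's mget does not accept a timeout argument, so the time limit has to live on the client itself. A minimal sketch of that setup, assuming a 5-second socket timeout (the exact value is my assumption, not specified in the issue):

import backoff
from redis import Redis
from redis.exceptions import TimeoutError

from sweepai.config.server import REDIS_URL

# Assumed setup: create the client with socket-level timeouts so a slow or
# dropped connection raises redis.exceptions.TimeoutError instead of hanging.
redis_client: Redis = Redis.from_url(REDIS_URL, socket_timeout=5, socket_connect_timeout=5)

@backoff.on_exception(backoff.expo, TimeoutError, max_tries=5)
def fetch_cached_embeddings(cache_keys: list[str]) -> list:
    # A timed-out read raises TimeoutError, which the decorator above
    # retries with exponential backoff (up to 5 attempts).
    return redis_client.mget(cache_keys)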

Step 3: 🔄️ Validating

Your changes have been successfully made to the branch sweep/sweep_connection_closed_by_server_redis. I have validated these changes using a syntax checker and a linter.


Tip

To recreate the pull request, edit the issue title or description.

This is an automated message generated by Sweep AI.
