AI_tools
from abc import ABC, abstractmethod
from dataclasses import dataclass
import os
from typing import cast

from dotenv import load_dotenv
from google import genai
from google.genai import types
from openai import OpenAI
from pydantic import BaseModel

from config import ENV_PATH

_ = load_dotenv(ENV_PATH)

# Only the first _MAX_CANDIDATES candidates are included in a find_related
# prompt; IDs outside that slice are never returned.
_MAX_CANDIDATES = 40


@dataclass
class PaperContent:
    """
    Wraps available content for a paper.

    Providers use the richest source they can consume: pdf > full_text > abstract.
    ``best_text()`` only looks at the text sources (full_text > abstract), so a
    provider that relies on it ignores ``pdf``.
    Populate full_text (TeX) or pdf (bytes) via downloads.py when needed.
    """

    abstract: str
    full_text: str | None = None  # raw TeX source
    pdf: bytes | None = None  # PDF bytes

    def best_text(self) -> str:
        """Return ``full_text`` when it is non-empty, otherwise ``abstract``."""
        return self.full_text or self.abstract


# ---------------------------------------------------------------------------
# Response schemas (structured-output targets for the model calls)
# ---------------------------------------------------------------------------

class _TagResponse(BaseModel):
    tags: list[str]


class SummaryResult(BaseModel):
    tldr: str
    key_contributions: list[str]


class _RelatedResponse(BaseModel):
    related_ids: list[str]


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _normalize_tags(raw_tags: list[str]) -> list[str]:
    """Convert raw model tags into unique ``#tag`` strings.

    Surrounding whitespace and leading ``#`` characters are removed and each
    run of inner whitespace becomes a single underscore, so tags contain no
    whitespace. Tags that end up empty, and repeats, are dropped; first-seen
    order is kept.
    """
    tags: list[str] = []
    for raw in raw_tags:
        name = "_".join(raw.strip().lstrip("#").split())
        hashtag = f"#{name}"
        # An all-whitespace or "#"-only tag would otherwise become a bare "#".
        if name and hashtag not in tags:
            tags.append(hashtag)
    return tags


def _candidate_block(candidates: list[tuple[str, str]]) -> str:
    """Render the first ``_MAX_CANDIDATES`` candidates as "ID: <id>" + abstract entries."""
    return "\n\n".join(
        f"ID: {pid}\n{abstract}" for pid, abstract in candidates[:_MAX_CANDIDATES]
    )


def _filter_related(
    related_ids: list[str],
    candidates: list[tuple[str, str]],
    threshold: int,
) -> list[str]:
    """Keep only model-returned IDs that were actually offered as candidates.

    Model output is not trusted: the result is restricted to IDs from the
    candidate slice shown in the prompt, de-duplicated, and capped at
    ``threshold`` entries. The model's order is preserved.
    """
    offered = {pid for pid, _ in candidates[:_MAX_CANDIDATES]}
    selected: list[str] = []
    for pid in related_ids:
        pid = pid.strip()
        if pid in offered and pid not in selected:
            selected.append(pid)
            if len(selected) >= threshold:
                break
    return selected


# ---------------------------------------------------------------------------
# AIProvider ABC
# ---------------------------------------------------------------------------

class AIProvider(ABC):
    """Unified interface for AI-powered paper analysis."""

    @abstractmethod
    def tag(self, content: PaperContent) -> list[str]:
        """Generate 3-5 relevant tags for a paper."""

    @abstractmethod
    def summarize(self, content: PaperContent) -> SummaryResult:
        """Return a one-sentence TLDR and 2-4 key contributions."""

    @abstractmethod
    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        """Return IDs of the most conceptually related papers from candidates.

        Args:
            content: The paper to compare against.
            candidates: ``(paper_id, abstract)`` pairs to choose from.
            threshold: Maximum number of IDs to return (a count limit, not a
                similarity cut-off).
        """


# ---------------------------------------------------------------------------
# GeminiProvider
# ---------------------------------------------------------------------------

class GeminiProvider(AIProvider):
    """AI provider backed by Google Gemini.

    Reads its API key from the ``GENAI_API_KEY_TAG_GEN`` environment variable.
    """

    def __init__(self, model: str = "gemini-2.0-flash") -> None:
        self._client = None  # created lazily by _get_client()
        self._model = model

    def _get_client(self):
        """Return the cached Gemini client, creating it on first use.

        Raises:
            EnvironmentError: If ``GENAI_API_KEY_TAG_GEN`` is unset or empty.
        """
        if self._client is None:
            api_key = os.getenv("GENAI_API_KEY_TAG_GEN")
            if not api_key:
                raise EnvironmentError("GENAI_API_KEY_TAG_GEN not set.")
            self._client = genai.Client(api_key=api_key)
        return self._client

    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
        """Send ``prompt`` plus the paper content and parse the reply into ``schema``.

        The PDF is sent when present; otherwise ``content.best_text()``.

        Raises:
            RuntimeError: If the SDK returned no parsed object.
        """
        parts: list = [types.Part.from_text(text=prompt)]
        if content.pdf:
            parts.append(types.Part.from_bytes(data=content.pdf, mime_type="application/pdf"))
        else:
            parts.append(types.Part.from_text(text=content.best_text()))
        response = self._get_client().models.generate_content(
            model=self._model,
            contents=parts,
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=schema,
            ),
        )
        parsed = response.parsed
        # parsed can be None (e.g. the reply did not match the schema); fail
        # here with a clear message rather than an AttributeError in callers.
        if parsed is None:
            raise RuntimeError(f"Gemini returned no parsable {schema.__name__} response.")
        return cast(BaseModel, parsed)

    def tag(self, content: PaperContent) -> list[str]:
        parsed = cast(_TagResponse, self._generate(
            "Generate 3-5 relevant Obsidian tags for this academic paper.",
            content, _TagResponse,
        ))
        return _normalize_tags(parsed.tags)

    def summarize(self, content: PaperContent) -> SummaryResult:
        return cast(SummaryResult, self._generate(
            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions.",
            content, SummaryResult,
        ))

    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        # Nothing to choose from: skip the API call entirely.
        if not candidates or threshold <= 0:
            return []
        parsed = cast(_RelatedResponse, self._generate(
            "Which of the following papers are most conceptually related to this one? "
            f"Return up to {threshold} paper IDs.\n\n{_candidate_block(candidates)}",
            content, _RelatedResponse,
        ))
        return _filter_related(parsed.related_ids, candidates, threshold)


# ---------------------------------------------------------------------------
# OpenAIProvider
# ---------------------------------------------------------------------------

class OpenAIProvider(AIProvider):
    """AI provider backed by OpenAI (GPT-4o, etc.).

    Reads its API key from ``OPENAI_API_KEY``. Only text is sent
    (``PaperContent.best_text()``), so PDF bytes are not used.
    """

    def __init__(self, model: str = "gpt-4o-mini") -> None:
        self._client = None  # created lazily by _get_client()
        self._model = model

    def _get_client(self):
        """Return the cached OpenAI client, creating it on first use.

        Raises:
            EnvironmentError: If ``OPENAI_API_KEY`` is unset or empty.
        """
        if self._client is None:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise EnvironmentError("OPENAI_API_KEY not set.")
            self._client = OpenAI(api_key=api_key)
        return self._client

    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
        """Send ``prompt`` as the system message and the paper text as the user message.

        Raises:
            RuntimeError: If the reply could not be parsed into ``schema``
                (e.g. the model refused).
        """
        text = content.best_text()
        response = self._get_client().beta.chat.completions.parse(
            model=self._model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": text},
            ],
            response_format=schema,
        )
        message = response.choices[0].message
        result = message.parsed
        # An explicit check instead of `assert`, which is stripped under `python -O`.
        if result is None:
            raise RuntimeError(
                f"OpenAI returned no parsable {schema.__name__} response "
                f"(refusal: {getattr(message, 'refusal', None)!r})."
            )
        return result

    def tag(self, content: PaperContent) -> list[str]:
        parsed = cast(_TagResponse, self._generate(
            "Generate 3-5 relevant Obsidian tags for this academic paper. "
            "Return as JSON with a 'tags' array of strings.",
            content, _TagResponse,
        ))
        return _normalize_tags(parsed.tags)

    def summarize(self, content: PaperContent) -> SummaryResult:
        return cast(SummaryResult, self._generate(
            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions. "
            "Return as JSON with 'tldr' (string) and 'key_contributions' (array of strings).",
            content, SummaryResult,
        ))

    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        # Nothing to choose from: skip the API call entirely.
        if not candidates or threshold <= 0:
            return []
        parsed = cast(_RelatedResponse, self._generate(
            "Which of the following papers are most conceptually related to this one? "
            f"Return up to {threshold} paper IDs as JSON with a 'related_ids' array.\n\n"
            f"{_candidate_block(candidates)}",
            content, _RelatedResponse,
        ))
        return _filter_related(parsed.related_ids, candidates, threshold)


# ---------------------------------------------------------------------------
# Module-level active provider + public API
# ---------------------------------------------------------------------------

_provider: AIProvider | None = None


def _get_provider() -> AIProvider:
    """Return the active provider, defaulting to a GeminiProvider on first use."""
    global _provider
    if _provider is None:
        _provider = GeminiProvider()
    return _provider


def set_provider(provider: AIProvider) -> None:
    """Switch the active AI provider."""
    global _provider
    _provider = provider


def tag(content: PaperContent, file_path: str | None = None) -> list[str]:
    """Generate 3-5 Obsidian tags. Optionally append them to ``file_path``.

    The tags are written space-separated on a new line; nothing is written
    when no tags were produced.
    """
    tags = _get_provider().tag(content)
    if file_path and tags:
        with open(file_path, "a", encoding="utf-8") as f:
            f.write("\n" + " ".join(tags))
    return tags


def summarize(content: PaperContent) -> SummaryResult:
    """Return a one-sentence TLDR and 2-4 key contributions."""
    return _get_provider().summarize(content)


def find_related(
    content: PaperContent,
    candidates: list[tuple[str, str]],  # [(paper_id, abstract), ...]
    threshold: int = 5,
) -> list[str]:
    """
    Return IDs of the most conceptually related papers from candidates.

    At most ``threshold`` IDs are returned, all drawn from ``candidates``.
    Useful for adding semantic edges to the graph beyond shared category/author.
    """
    return _get_provider().find_related(content, candidates, threshold)
18@dataclass 19class PaperContent: 20 """ 21 Wraps available content for a paper. 22 Functions use the richest source available: pdf > full_text > abstract. 23 Populate full_text (TeX) or pdf (bytes) via downloads.py when needed. 24 """ 25 abstract: str 26 full_text: str | None = None # raw TeX source 27 pdf: bytes | None = None # PDF bytes 28 29 def best_text(self) -> str: 30 if self.full_text: 31 return self.full_text 32 return self.abstract
Wraps available content for a paper. Providers use the richest source they can consume: pdf > full_text > abstract (best_text() itself only considers full_text and abstract). Populate full_text (TeX) or pdf (bytes) via downloads.py when needed.
!!! abstract "Usage Documentation" Models
A base class for creating Pydantic models.
Attributes:
__class_vars__: The names of the class variables defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The core schema of the model.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
__args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
__pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
__pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
__pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
is set to `'allow'`.
__pydantic_fields_set__: The names of fields explicitly set during instantiation.
__pydantic_private__: Values of private attributes set on the model instance.
class AIProvider(ABC):
    """Unified interface for AI-powered paper analysis.

    Concrete providers implement every abstract method below on top of a
    specific model backend.
    """

    @abstractmethod
    def tag(self, content: PaperContent) -> list[str]:
        """Generate 3-5 relevant tags for a paper."""

    @abstractmethod
    def summarize(self, content: PaperContent) -> SummaryResult:
        """Return a one-sentence TLDR and 2-4 key contributions."""

    @abstractmethod
    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        """Return IDs of the most conceptually related papers from candidates.

        Args:
            content: The paper to compare against.
            candidates: ``(paper_id, abstract)`` pairs to choose from.
            threshold: Maximum number of IDs to return (a count limit, not a
                similarity cut-off).
        """
Unified interface for AI-powered paper analysis.
@abstractmethod
def tag(self, content: PaperContent) -> list[str]:
    """Generate 3-5 relevant tags for a paper.

    Args:
        content: The paper's available content.

    Returns:
        The tag strings produced by the concrete provider.
    """
Generate 3-5 relevant tags for a paper.
class GeminiProvider(AIProvider):
    """AI provider backed by Google Gemini.

    Reads its API key from the ``GENAI_API_KEY_TAG_GEN`` environment variable.
    """

    # Only this many candidates are shown to the model in find_related.
    _MAX_CANDIDATES = 40

    def __init__(self, model: str = "gemini-2.0-flash") -> None:
        self._client = None  # created lazily by _get_client()
        self._model = model

    def _get_client(self):
        """Return the cached Gemini client, creating it on first use.

        Raises:
            EnvironmentError: If ``GENAI_API_KEY_TAG_GEN`` is unset or empty.
        """
        if self._client is None:
            api_key = os.getenv("GENAI_API_KEY_TAG_GEN")
            if not api_key:
                raise EnvironmentError("GENAI_API_KEY_TAG_GEN not set.")
            self._client = genai.Client(api_key=api_key)
        return self._client

    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
        """Send ``prompt`` plus the paper content and parse the reply into ``schema``.

        The PDF is sent when present; otherwise ``content.best_text()``.

        Raises:
            RuntimeError: If the SDK returned no parsed object.
        """
        parts: list = [types.Part.from_text(text=prompt)]
        if content.pdf:
            parts.append(types.Part.from_bytes(data=content.pdf, mime_type="application/pdf"))
        else:
            parts.append(types.Part.from_text(text=content.best_text()))
        response = self._get_client().models.generate_content(
            model=self._model,
            contents=parts,
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=schema,
            ),
        )
        parsed = response.parsed
        # parsed can be None (e.g. the reply did not match the schema); fail
        # here with a clear message rather than an AttributeError in callers.
        if parsed is None:
            raise RuntimeError(f"Gemini returned no parsable {schema.__name__} response.")
        return cast(BaseModel, parsed)

    @staticmethod
    def _normalize_tags(raw_tags: list[str]) -> list[str]:
        """Turn raw tags into unique ``#word_word`` strings, dropping empty ones."""
        tags: list[str] = []
        for raw in raw_tags:
            name = "_".join(raw.strip().lstrip("#").split())
            hashtag = f"#{name}"
            if name and hashtag not in tags:
                tags.append(hashtag)
        return tags

    def tag(self, content: PaperContent) -> list[str]:
        parsed = cast(_TagResponse, self._generate(
            "Generate 3-5 relevant Obsidian tags for this academic paper.",
            content, _TagResponse,
        ))
        return self._normalize_tags(parsed.tags)

    def summarize(self, content: PaperContent) -> SummaryResult:
        return cast(SummaryResult, self._generate(
            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions.",
            content, SummaryResult,
        ))

    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        """Return up to ``threshold`` IDs, all drawn from the candidates shown to the model."""
        if not candidates or threshold <= 0:
            return []
        shown = candidates[: self._MAX_CANDIDATES]
        candidate_block = "\n\n".join(f"ID: {pid}\n{ab}" for pid, ab in shown)
        parsed = cast(_RelatedResponse, self._generate(
            "Which of the following papers are most conceptually related to this one? "
            f"Return up to {threshold} paper IDs.\n\n{candidate_block}",
            content, _RelatedResponse,
        ))
        # Model output is not trusted: drop unknown or repeated IDs and cap the count.
        offered = {pid for pid, _ in shown}
        selected: list[str] = []
        for pid in parsed.related_ids:
            pid = pid.strip()
            if pid in offered and pid not in selected:
                selected.append(pid)
                if len(selected) >= threshold:
                    break
        return selected
" 134 f"Return up to {threshold} paper IDs.\n\n{candidate_block}", 135 content, _RelatedResponse, 136 )) 137 return parsed.related_ids
AI provider backed by Google Gemini.
def tag(self, content: PaperContent) -> list[str]:
    """Generate 3-5 Obsidian tags for a paper using Gemini.

    Raw tags are normalized: surrounding whitespace and leading ``#`` are
    removed, inner whitespace runs become ``_``, and empty or duplicate tags
    are dropped.
    """
    parsed = cast(_TagResponse, self._generate(
        "Generate 3-5 relevant Obsidian tags for this academic paper.",
        content, _TagResponse,
    ))
    tags: list[str] = []
    for raw in parsed.tags:
        name = "_".join(raw.strip().lstrip("#").split())
        hashtag = f"#{name}"
        # An all-whitespace or "#"-only tag would otherwise become a bare "#".
        if name and hashtag not in tags:
            tags.append(hashtag)
    return tags
Generate 3-5 relevant tags for a paper.
class OpenAIProvider(AIProvider):
    """AI provider backed by OpenAI (GPT-4o, etc.).

    Reads its API key from ``OPENAI_API_KEY``. Only text is sent
    (``PaperContent.best_text()``), so PDF bytes are not used.
    """

    # Only this many candidates are shown to the model in find_related.
    _MAX_CANDIDATES = 40

    def __init__(self, model: str = "gpt-4o-mini") -> None:
        self._client = None  # created lazily by _get_client()
        self._model = model

    def _get_client(self):
        """Return the cached OpenAI client, creating it on first use.

        Raises:
            EnvironmentError: If ``OPENAI_API_KEY`` is unset or empty.
        """
        if self._client is None:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise EnvironmentError("OPENAI_API_KEY not set.")
            self._client = OpenAI(api_key=api_key)
        return self._client

    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
        """Send ``prompt`` as the system message and the paper text as the user message.

        Raises:
            RuntimeError: If the reply could not be parsed into ``schema``
                (e.g. the model refused).
        """
        text = content.best_text()
        response = self._get_client().beta.chat.completions.parse(
            model=self._model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": text},
            ],
            response_format=schema,
        )
        message = response.choices[0].message
        result = message.parsed
        # An explicit check instead of `assert`, which is stripped under `python -O`.
        if result is None:
            raise RuntimeError(
                f"OpenAI returned no parsable {schema.__name__} response "
                f"(refusal: {getattr(message, 'refusal', None)!r})."
            )
        return result

    @staticmethod
    def _normalize_tags(raw_tags: list[str]) -> list[str]:
        """Turn raw tags into unique ``#word_word`` strings, dropping empty ones."""
        tags: list[str] = []
        for raw in raw_tags:
            name = "_".join(raw.strip().lstrip("#").split())
            hashtag = f"#{name}"
            if name and hashtag not in tags:
                tags.append(hashtag)
        return tags

    def tag(self, content: PaperContent) -> list[str]:
        parsed = cast(_TagResponse, self._generate(
            "Generate 3-5 relevant Obsidian tags for this academic paper. "
            "Return as JSON with a 'tags' array of strings.",
            content, _TagResponse,
        ))
        return self._normalize_tags(parsed.tags)

    def summarize(self, content: PaperContent) -> SummaryResult:
        return cast(SummaryResult, self._generate(
            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions. "
            "Return as JSON with 'tldr' (string) and 'key_contributions' (array of strings).",
            content, SummaryResult,
        ))

    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        """Return up to ``threshold`` IDs, all drawn from the candidates shown to the model."""
        if not candidates or threshold <= 0:
            return []
        shown = candidates[: self._MAX_CANDIDATES]
        candidate_block = "\n\n".join(f"ID: {pid}\n{ab}" for pid, ab in shown)
        parsed = cast(_RelatedResponse, self._generate(
            "Which of the following papers are most conceptually related to this one? "
            f"Return up to {threshold} paper IDs as JSON with a 'related_ids' array.\n\n{candidate_block}",
            content, _RelatedResponse,
        ))
        # Model output is not trusted: drop unknown or repeated IDs and cap the count.
        offered = {pid for pid, _ in shown}
        selected: list[str] = []
        for pid in parsed.related_ids:
            pid = pid.strip()
            if pid in offered and pid not in selected:
                selected.append(pid)
                if len(selected) >= threshold:
                    break
        return selected
" 198 f"Return up to {threshold} paper IDs as JSON with a 'related_ids' array.\n\n{candidate_block}", 199 content, _RelatedResponse, 200 )) 201 return parsed.related_ids
AI provider backed by OpenAI (GPT-4o, etc.).
def tag(self, content: PaperContent) -> list[str]:
    """Generate 3-5 Obsidian tags for a paper using OpenAI.

    Raw tags are normalized: surrounding whitespace and leading ``#`` are
    removed, inner whitespace runs become ``_``, and empty or duplicate tags
    are dropped.
    """
    parsed = cast(_TagResponse, self._generate(
        "Generate 3-5 relevant Obsidian tags for this academic paper. "
        "Return as JSON with a 'tags' array of strings.",
        content, _TagResponse,
    ))
    tags: list[str] = []
    for raw in parsed.tags:
        name = "_".join(raw.strip().lstrip("#").split())
        hashtag = f"#{name}"
        # An all-whitespace or "#"-only tag would otherwise become a bare "#".
        if name and hashtag not in tags:
            tags.append(hashtag)
    return tags
Generate 3-5 relevant tags for a paper.
def summarize(self, content: PaperContent) -> SummaryResult:
    """Ask the OpenAI model for a TLDR plus key contributions.

    Args:
        content: The paper to summarize; ``_generate`` decides which text is sent.

    Returns:
        The ``SummaryResult`` parsed from the model's structured reply.
    """
    instructions = (
        "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions. "
        "Return as JSON with 'tldr' (string) and 'key_contributions' (array of strings)."
    )
    summary = self._generate(instructions, content, SummaryResult)
    return cast(SummaryResult, summary)
Return a one-sentence TLDR and 2-4 key contributions.
def set_provider(provider: AIProvider) -> None:
    """Install ``provider`` as the module-wide AI backend.

    Later calls that go through ``_get_provider()`` use it instead of the
    lazily created default.

    Args:
        provider: Any concrete ``AIProvider`` implementation.
    """
    global _provider
    _provider = provider
Switch the active AI provider.
def tag(content: PaperContent, file_path: str | None = None) -> list[str]:
    """Generate 3-5 Obsidian tags with the active provider.

    Args:
        content: The paper to tag.
        file_path: If given, the tags are appended to this file on a new line,
            space-separated. Nothing is written when no tags were produced.

    Returns:
        The generated tags.
    """
    tags = _get_provider().tag(content)
    # Skip the write when there is nothing to add, so no stray blank line is appended.
    if file_path and tags:
        with open(file_path, "a", encoding="utf-8") as f:
            f.write("\n" + " ".join(tags))
    return tags
Generate 3-5 Obsidian tags. Optionally append to file_path.
def summarize(content: PaperContent) -> SummaryResult:
    """Summarize a paper with the currently active provider.

    Returns:
        A ``SummaryResult`` holding a one-sentence TLDR and 2-4 key contributions.
    """
    active = _get_provider()
    return active.summarize(content)
Return a one-sentence TLDR and 2-4 key contributions.