AI_tools

  1from abc import ABC, abstractmethod
  2from dataclasses import dataclass
  3import os
  4from typing import cast
  5
  6from dotenv import load_dotenv
  7from google import genai
  8from google.genai import types
  9from openai import OpenAI
 10from pydantic import BaseModel
 11
 12from config import ENV_PATH
 13
# Load the variables defined in the .env file at ENV_PATH when this module is
# imported. The providers below read their API keys with os.getenv, so this
# must run before any provider creates its client.
# The return value is bound to `_` and is not used anywhere in this module.
_ = load_dotenv(ENV_PATH)
 15
 16
 17@dataclass
 18class PaperContent:
 19    """
 20    Wraps available content for a paper.
 21    Functions use the richest source available: pdf > full_text > abstract.
 22    Populate full_text (TeX) or pdf (bytes) via downloads.py when needed.
 23    """
 24    abstract: str
 25    full_text: str | None = None  # raw TeX source
 26    pdf: bytes | None = None      # PDF bytes
 27
 28    def best_text(self) -> str:
 29        if self.full_text:
 30            return self.full_text
 31        return self.abstract
 32
 33
 34# Response schemas
 35
# Structured-output schema that the providers' tag() methods pass to the model.
class _TagResponse(BaseModel):
    tags: list[str]  # raw tag strings; tag() normalizes them before returning
 38
 39
# Public result type of summarize(); the providers also pass this class to the
# model as the structured-output schema.
class SummaryResult(BaseModel):
    tldr: str  # the prompts ask for a one-sentence summary
    key_contributions: list[str]  # the prompts ask for 2-4 entries; the count is not validated here
 43
 44
# Structured-output schema that the providers' find_related() methods pass to
# the model.
class _RelatedResponse(BaseModel):
    related_ids: list[str]  # paper IDs exactly as produced by the model
 47
 48
 49# ---------------------------------------------------------------------------
 50# AIProvider ABC
 51# ---------------------------------------------------------------------------
 52
class AIProvider(ABC):
    """Unified interface for AI-powered paper analysis.

    Concrete providers must implement all three abstract methods below. The
    module-level tag(), summarize() and find_related() functions delegate to
    whichever provider is currently active.
    """

    @abstractmethod
    def tag(self, content: PaperContent) -> list[str]:
        """Generate 3-5 relevant tags for a paper.

        Args:
            content: The content available for the paper.

        Returns:
            The generated tags as a list of strings.
        """

    @abstractmethod
    def summarize(self, content: PaperContent) -> SummaryResult:
        """Return a one-sentence TLDR and 2-4 key contributions.

        Args:
            content: The content available for the paper.

        Returns:
            A SummaryResult holding ``tldr`` and ``key_contributions``.
        """

    @abstractmethod
    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        """Return IDs of the most conceptually related papers from candidates.

        Args:
            content: The paper to find relatives for.
            candidates: ``(paper_id, abstract)`` pairs to choose from.
            threshold: The providers in this file present this value to the
                model as the maximum number of IDs to return.

        Returns:
            A list of paper IDs.
        """
 72
 73
 74# ---------------------------------------------------------------------------
 75# GeminiProvider
 76# ---------------------------------------------------------------------------
 77
 78class GeminiProvider(AIProvider):
 79    """AI provider backed by Google Gemini."""
 80
 81    def __init__(self) -> None:
 82        self._client = None  # lazy init
 83
 84    def _get_client(self):
 85        if self._client is None:
 86            api_key = os.getenv("GENAI_API_KEY_TAG_GEN")
 87            if not api_key:
 88                raise EnvironmentError("GENAI_API_KEY_TAG_GEN not set.")
 89            self._client = genai.Client(api_key=api_key)
 90        return self._client
 91
 92    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
 93        parts: list = [types.Part.from_text(text=prompt)]
 94        if content.pdf:
 95            parts.append(types.Part.from_bytes(data=content.pdf, mime_type="application/pdf"))
 96        elif content.full_text:
 97            parts.append(types.Part.from_text(text=content.full_text))
 98        else:
 99            parts.append(types.Part.from_text(text=content.abstract))
100        response = self._get_client().models.generate_content(
101            model="gemini-2.0-flash",
102            contents=parts,
103            config=types.GenerateContentConfig(
104                response_mime_type="application/json",
105                response_schema=schema,
106            ),
107        )
108        return cast(BaseModel, response.parsed)
109
110    def tag(self, content: PaperContent) -> list[str]:
111        parsed = cast(_TagResponse, self._generate(
112            "Generate 3-5 relevant Obsidian tags for this academic paper.",
113            content, _TagResponse,
114        ))
115        return [f"#{t.strip().lstrip('#').replace(' ', '_')}" for t in parsed.tags]
116
117    def summarize(self, content: PaperContent) -> SummaryResult:
118        return cast(SummaryResult, self._generate(
119            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions.",
120            content, SummaryResult,
121        ))
122
123    def find_related(
124        self,
125        content: PaperContent,
126        candidates: list[tuple[str, str]],
127        threshold: int = 5,
128    ) -> list[str]:
129        candidate_block = "\n\n".join(
130            f"ID: {pid}\n{ab}" for pid, ab in candidates[:40])
131        parsed = cast(_RelatedResponse, self._generate(
132            f"Which of the following papers are most conceptually related to this one? "
133            f"Return up to {threshold} paper IDs.\n\n{candidate_block}",
134            content, _RelatedResponse,
135        ))
136        return parsed.related_ids
137
138
139# ---------------------------------------------------------------------------
140# OpenAIProvider
141# ---------------------------------------------------------------------------
142
class OpenAIProvider(AIProvider):
    """AI provider backed by OpenAI (GPT-4o, etc.).

    The client is created lazily on the first request. Only textual content
    is sent: ``content.best_text()`` is used, so ``content.pdf`` is ignored
    by this provider.

    Args:
        model: Name of the OpenAI model used for every request.
    """

    # Only this many leading candidates are embedded in a find_related prompt.
    _MAX_CANDIDATES = 40

    def __init__(self, model: str = "gpt-4o-mini") -> None:
        self._client = None  # lazy init
        self._model = model

    def _get_client(self):
        """Return the cached OpenAI client, creating it on first use.

        Raises:
            EnvironmentError: If OPENAI_API_KEY is unset or empty.
        """
        if self._client is None:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise EnvironmentError("OPENAI_API_KEY not set.")
            self._client = OpenAI(api_key=api_key)
        return self._client

    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
        """Send the prompt (system) and the paper text (user); parse the reply.

        Args:
            prompt: Instruction text sent as the system message.
            content: Paper content; its ``best_text()`` is the user message.
            schema: Pydantic model class requested as the response format.

        Returns:
            The parsed response object from the first choice.

        Raises:
            ValueError: If the first choice carries no parsed result, for
                example because the model refused to answer.
        """
        text = content.best_text()
        response = self._get_client().beta.chat.completions.parse(
            model=self._model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": text},
            ],
            response_format=schema,
        )
        message = response.choices[0].message
        result = message.parsed
        if result is None:
            # An explicit raise (rather than assert) still runs under
            # `python -O` and can surface the refusal text when there is one.
            refusal = getattr(message, "refusal", None)
            detail = f": {refusal}" if refusal else "."
            raise ValueError(
                f"OpenAI returned no parsable {schema.__name__} response{detail}"
            )
        return result

    @staticmethod
    def _normalize_tags(raw_tags: list[str]) -> list[str]:
        """Turn raw model tags into ``#tag_name`` form, dropping empty ones."""
        normalized: list[str] = []
        for raw in raw_tags:
            # Strip again after removing '#' so "# deep learning" does not
            # become "#_deep_learning".
            name = raw.strip().lstrip("#").strip().replace(" ", "_")
            if name:
                normalized.append(f"#{name}")
        return normalized

    @staticmethod
    def _select_related(
        returned_ids: list[str],
        shortlist: list[tuple[str, str]],
        limit: int,
    ) -> list[str]:
        """Keep only IDs that were offered, without duplicates, capped at ``limit``."""
        offered = {pid for pid, _ in shortlist}
        selected: list[str] = []
        for raw in returned_ids:
            rid = raw if raw in offered else raw.strip()
            if rid in offered and rid not in selected:
                selected.append(rid)
        return selected[:limit]

    def tag(self, content: PaperContent) -> list[str]:
        """Generate 3-5 Obsidian-style tags, each returned as ``#tag_name``."""
        parsed = cast(_TagResponse, self._generate(
            "Generate 3-5 relevant Obsidian tags for this academic paper. "
            "Return as JSON with a 'tags' array of strings.",
            content, _TagResponse,
        ))
        return self._normalize_tags(parsed.tags)

    def summarize(self, content: PaperContent) -> SummaryResult:
        """Return a one-sentence TLDR and 2-4 key contributions."""
        return cast(SummaryResult, self._generate(
            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions. "
            "Return as JSON with 'tldr' (string) and 'key_contributions' (array of strings).",
            content, SummaryResult,
        ))

    def find_related(
        self,
        content: PaperContent,
        candidates: list[tuple[str, str]],
        threshold: int = 5,
    ) -> list[str]:
        """Return IDs of the candidates most conceptually related to ``content``.

        Only the first ``_MAX_CANDIDATES`` candidates are shown to the model.
        The model's answer is filtered to IDs that were actually offered,
        de-duplicated, and truncated to ``threshold`` entries.

        Args:
            content: The paper to find relatives for.
            candidates: ``(paper_id, abstract)`` pairs to choose from.
            threshold: Maximum number of IDs to return.

        Returns:
            At most ``threshold`` candidate IDs; empty when there are no
            candidates or ``threshold`` is not positive (no request is made).
        """
        shortlist = candidates[: self._MAX_CANDIDATES]
        if not shortlist or threshold <= 0:
            return []
        candidate_block = "\n\n".join(
            f"ID: {pid}\n{ab}" for pid, ab in shortlist)
        parsed = cast(_RelatedResponse, self._generate(
            f"Which of the following papers are most conceptually related to this one? "
            f"Return up to {threshold} paper IDs as JSON with a 'related_ids' array.\n\n{candidate_block}",
            content, _RelatedResponse,
        ))
        return self._select_related(parsed.related_ids, shortlist, threshold)
201
202
203# ---------------------------------------------------------------------------
204# Module-level active provider + public API
205# ---------------------------------------------------------------------------
206
# Active provider shared by the public functions below. It stays None until
# first use or until set_provider() installs one.
_provider: AIProvider | None = None


def _get_provider() -> AIProvider:
    """Return the active provider, installing a GeminiProvider if none is set."""
    global _provider
    if _provider is not None:
        return _provider
    _provider = GeminiProvider()
    return _provider
215
216
def set_provider(provider: AIProvider) -> None:
    """Install ``provider`` as the active AI provider.

    Later calls to the module-level tag(), summarize() and find_related()
    are routed to it; any previously active provider is simply replaced.
    """
    global _provider
    _provider = provider
221
222
def tag(content: PaperContent, file_path: str | None = None) -> list[str]:
    """Generate 3-5 Obsidian tags with the active provider.

    When ``file_path`` is given (and non-empty), a newline followed by the
    space-separated tags is appended to that file using UTF-8.

    Returns:
        The tags produced by the provider.
    """
    generated = _get_provider().tag(content)
    if not file_path:
        return generated
    tag_line = "\n" + " ".join(generated)
    with open(file_path, "a", encoding="utf-8") as out:
        out.write(tag_line)
    return generated
230
231
def summarize(content: PaperContent) -> SummaryResult:
    """Summarize a paper with the active provider.

    Returns:
        The provider's SummaryResult (a one-sentence TLDR and 2-4 key
        contributions).
    """
    provider = _get_provider()
    return provider.summarize(content)
235
236
def find_related(
    content: PaperContent,
    candidates: list[tuple[str, str]],   # [(paper_id, abstract), ...]
    threshold: int = 5,
) -> list[str]:
    """
    Ask the active provider which candidates are most conceptually related.

    Returns the IDs chosen by the provider. Useful for adding semantic edges
    to the graph beyond shared category/author.
    """
    provider = _get_provider()
    return provider.find_related(content, candidates, threshold)
@dataclass
class PaperContent:
18@dataclass
19class PaperContent:
20    """
21    Wraps available content for a paper.
22    Functions use the richest source available: pdf > full_text > abstract.
23    Populate full_text (TeX) or pdf (bytes) via downloads.py when needed.
24    """
25    abstract: str
26    full_text: str | None = None  # raw TeX source
27    pdf: bytes | None = None      # PDF bytes
28
29    def best_text(self) -> str:
30        if self.full_text:
31            return self.full_text
32        return self.abstract

Wraps available content for a paper. Functions use the richest source available: pdf > full_text > abstract. Populate full_text (TeX) or pdf (bytes) via downloads.py when needed.

PaperContent( abstract: str, full_text: str | None = None, pdf: bytes | None = None)
abstract: str
full_text: str | None = None
pdf: bytes | None = None
def best_text(self) -> str:
29    def best_text(self) -> str:
30        if self.full_text:
31            return self.full_text
32        return self.abstract
class SummaryResult(pydantic.main.BaseModel):
41class SummaryResult(BaseModel):
42    tldr: str
43    key_contributions: list[str]

Usage documentation: see the Pydantic "Models" guide.

A base class for creating Pydantic models.

Attributes:
__class_vars__: The names of the class variables defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The synthesized `__init__` [Signature][inspect.Signature] of the model.

__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The core schema of the model.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
    This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
    __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
__pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
__pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.

__pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.

__pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
    is set to `'allow'`.
__pydantic_fields_set__: The names of fields explicitly set during instantiation.
__pydantic_private__: Values of private attributes set on the model instance.
tldr: str = PydanticUndefined
key_contributions: list[str] = PydanticUndefined
class AIProvider(abc.ABC):
54class AIProvider(ABC):
55    """Unified interface for AI-powered paper analysis."""
56
57    @abstractmethod
58    def tag(self, content: PaperContent) -> list[str]:
59        """Generate 3-5 relevant tags for a paper."""
60
61    @abstractmethod
62    def summarize(self, content: PaperContent) -> SummaryResult:
63        """Return a one-sentence TLDR and 2-4 key contributions."""
64
65    @abstractmethod
66    def find_related(
67        self,
68        content: PaperContent,
69        candidates: list[tuple[str, str]],
70        threshold: int = 5,
71    ) -> list[str]:
72        """Return IDs of the most conceptually related papers from candidates."""

Unified interface for AI-powered paper analysis.

@abstractmethod
def tag(self, content: PaperContent) -> list[str]:
57    @abstractmethod
58    def tag(self, content: PaperContent) -> list[str]:
59        """Generate 3-5 relevant tags for a paper."""

Generate 3-5 relevant tags for a paper.

@abstractmethod
def summarize(self, content: PaperContent) -> SummaryResult:
61    @abstractmethod
62    def summarize(self, content: PaperContent) -> SummaryResult:
63        """Return a one-sentence TLDR and 2-4 key contributions."""

Return a one-sentence TLDR and 2-4 key contributions.

class GeminiProvider(AIProvider):
 79class GeminiProvider(AIProvider):
 80    """AI provider backed by Google Gemini."""
 81
 82    def __init__(self) -> None:
 83        self._client = None  # lazy init
 84
 85    def _get_client(self):
 86        if self._client is None:
 87            api_key = os.getenv("GENAI_API_KEY_TAG_GEN")
 88            if not api_key:
 89                raise EnvironmentError("GENAI_API_KEY_TAG_GEN not set.")
 90            self._client = genai.Client(api_key=api_key)
 91        return self._client
 92
 93    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
 94        parts: list = [types.Part.from_text(text=prompt)]
 95        if content.pdf:
 96            parts.append(types.Part.from_bytes(data=content.pdf, mime_type="application/pdf"))
 97        elif content.full_text:
 98            parts.append(types.Part.from_text(text=content.full_text))
 99        else:
100            parts.append(types.Part.from_text(text=content.abstract))
101        response = self._get_client().models.generate_content(
102            model="gemini-2.0-flash",
103            contents=parts,
104            config=types.GenerateContentConfig(
105                response_mime_type="application/json",
106                response_schema=schema,
107            ),
108        )
109        return cast(BaseModel, response.parsed)
110
111    def tag(self, content: PaperContent) -> list[str]:
112        parsed = cast(_TagResponse, self._generate(
113            "Generate 3-5 relevant Obsidian tags for this academic paper.",
114            content, _TagResponse,
115        ))
116        return [f"#{t.strip().lstrip('#').replace(' ', '_')}" for t in parsed.tags]
117
118    def summarize(self, content: PaperContent) -> SummaryResult:
119        return cast(SummaryResult, self._generate(
120            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions.",
121            content, SummaryResult,
122        ))
123
124    def find_related(
125        self,
126        content: PaperContent,
127        candidates: list[tuple[str, str]],
128        threshold: int = 5,
129    ) -> list[str]:
130        candidate_block = "\n\n".join(
131            f"ID: {pid}\n{ab}" for pid, ab in candidates[:40])
132        parsed = cast(_RelatedResponse, self._generate(
133            f"Which of the following papers are most conceptually related to this one? "
134            f"Return up to {threshold} paper IDs.\n\n{candidate_block}",
135            content, _RelatedResponse,
136        ))
137        return parsed.related_ids

AI provider backed by Google Gemini.

def tag(self, content: PaperContent) -> list[str]:
111    def tag(self, content: PaperContent) -> list[str]:
112        parsed = cast(_TagResponse, self._generate(
113            "Generate 3-5 relevant Obsidian tags for this academic paper.",
114            content, _TagResponse,
115        ))
116        return [f"#{t.strip().lstrip('#').replace(' ', '_')}" for t in parsed.tags]

Generate 3-5 relevant tags for a paper.

def summarize(self, content: PaperContent) -> SummaryResult:
118    def summarize(self, content: PaperContent) -> SummaryResult:
119        return cast(SummaryResult, self._generate(
120            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions.",
121            content, SummaryResult,
122        ))

Return a one-sentence TLDR and 2-4 key contributions.

class OpenAIProvider(AIProvider):
144class OpenAIProvider(AIProvider):
145    """AI provider backed by OpenAI (GPT-4o, etc.)."""
146
147    def __init__(self, model: str = "gpt-4o-mini") -> None:
148        self._client = None  # lazy init
149        self._model = model
150
151    def _get_client(self):
152        if self._client is None:
153            api_key = os.getenv("OPENAI_API_KEY")
154            if not api_key:
155                raise EnvironmentError("OPENAI_API_KEY not set.")
156            self._client = OpenAI(api_key=api_key)
157        return self._client
158
159    def _generate(self, prompt: str, content: PaperContent, schema: type[BaseModel]) -> BaseModel:
160        text = content.best_text()
161        response = self._get_client().beta.chat.completions.parse(
162            model=self._model,
163            messages=[
164                {"role": "system", "content": prompt},
165                {"role": "user", "content": text},
166            ],
167            response_format=schema,
168        )
169        result = response.choices[0].message.parsed
170        assert result is not None
171        return result
172
173    def tag(self, content: PaperContent) -> list[str]:
174        parsed = cast(_TagResponse, self._generate(
175            "Generate 3-5 relevant Obsidian tags for this academic paper. "
176            "Return as JSON with a 'tags' array of strings.",
177            content, _TagResponse,
178        ))
179        return [f"#{t.strip().lstrip('#').replace(' ', '_')}" for t in parsed.tags]
180
181    def summarize(self, content: PaperContent) -> SummaryResult:
182        return cast(SummaryResult, self._generate(
183            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions. "
184            "Return as JSON with 'tldr' (string) and 'key_contributions' (array of strings).",
185            content, SummaryResult,
186        ))
187
188    def find_related(
189        self,
190        content: PaperContent,
191        candidates: list[tuple[str, str]],
192        threshold: int = 5,
193    ) -> list[str]:
194        candidate_block = "\n\n".join(
195            f"ID: {pid}\n{ab}" for pid, ab in candidates[:40])
196        parsed = cast(_RelatedResponse, self._generate(
197            f"Which of the following papers are most conceptually related to this one? "
198            f"Return up to {threshold} paper IDs as JSON with a 'related_ids' array.\n\n{candidate_block}",
199            content, _RelatedResponse,
200        ))
201        return parsed.related_ids

AI provider backed by OpenAI (GPT-4o, etc.).

OpenAIProvider(model: str = 'gpt-4o-mini')
147    def __init__(self, model: str = "gpt-4o-mini") -> None:
148        self._client = None  # lazy init
149        self._model = model
def tag(self, content: PaperContent) -> list[str]:
173    def tag(self, content: PaperContent) -> list[str]:
174        parsed = cast(_TagResponse, self._generate(
175            "Generate 3-5 relevant Obsidian tags for this academic paper. "
176            "Return as JSON with a 'tags' array of strings.",
177            content, _TagResponse,
178        ))
179        return [f"#{t.strip().lstrip('#').replace(' ', '_')}" for t in parsed.tags]

Generate 3-5 relevant tags for a paper.

def summarize(self, content: PaperContent) -> SummaryResult:
181    def summarize(self, content: PaperContent) -> SummaryResult:
182        return cast(SummaryResult, self._generate(
183            "Summarize this academic paper into a one-sentence TLDR and 2-4 key contributions. "
184            "Return as JSON with 'tldr' (string) and 'key_contributions' (array of strings).",
185            content, SummaryResult,
186        ))

Return a one-sentence TLDR and 2-4 key contributions.

def set_provider(provider: AIProvider) -> None:
218def set_provider(provider: AIProvider) -> None:
219    """Switch the active AI provider."""
220    global _provider
221    _provider = provider

Switch the active AI provider.

def tag( content: PaperContent, file_path: str | None = None) -> list[str]:
224def tag(content: PaperContent, file_path: str | None = None) -> list[str]:
225    """Generate 3-5 Obsidian tags. Optionally append to file_path."""
226    tags = _get_provider().tag(content)
227    if file_path:
228        with open(file_path, "a", encoding="utf-8") as f:
229            f.write("\n" + " ".join(tags))
230    return tags

Generate 3-5 Obsidian tags. Optionally append to file_path.

def summarize(content: PaperContent) -> SummaryResult:
233def summarize(content: PaperContent) -> SummaryResult:
234    """Return a one-sentence TLDR and 2-4 key contributions."""
235    return _get_provider().summarize(content)

Return a one-sentence TLDR and 2-4 key contributions.