Cache Augmented Generation: An Innovative Alternative to Traditional RAG, But Still a Long Way to Go
Cache Augmented Generation could be the future, if LLMs become small and fast.
"Don't do RAG, do CAG instead" has been flooding Twitter and LinkedIn lately, and predictably, people started declaring RAG dead because of this. But here's the thing - Cache Augmented Generation (CAG) isn't some revolutionary new approach. It's been quietly powering the systems of major players like OpenAI and Anthropic for a while now, and Google was actually the pioneer in bringing caching back into the spotlight. Despite all the buzz, I found myself frustrated trying to find a straightforward implementation. The original paper's approach feels unnecessarily complex and doesn't scale well in practice. So I decided to build my own implementation, focusing on simplicity and practicality.
Think about how your brain works when having a conversation. You don't recalculate everything you know about the topic each time you speak - you keep relevant information readily available in your working memory. That's essentially what KV caching does for language models.
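To make the working-memory idea concrete, here is a toy sketch in plain Python (no model involved). It assumes a single attention head, and the `project` helper is a made-up stand-in for learned projections. The point is only that keys and values computed once for a fixed prefix can be stored and reused, and the result matches recomputing everything from scratch.

```python
import math

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]

def project(token_vec, factor):
    """Hypothetical stand-in for a learned projection: just scales the vector."""
    return [factor * x for x in token_vec]

# Toy "embeddings" for a knowledge prefix and one new question token.
knowledge = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
question = [0.5, -0.5]

# Without a cache: keys and values are recomputed for every token.
all_tokens = knowledge + [question]
keys_full = [project(t, 0.5) for t in all_tokens]
values_full = [project(t, 2.0) for t in all_tokens]
out_recomputed = attend(project(question, 1.0), keys_full, values_full)

# With a cache: the prefix keys and values are computed once and stored...
cached_keys = [project(t, 0.5) for t in knowledge]
cached_values = [project(t, 2.0) for t in knowledge]
# ...and only the new token's key and value are computed and appended.
cached_keys.append(project(question, 0.5))
cached_values.append(project(question, 2.0))
out_cached = attend(project(question, 1.0), cached_keys, cached_values)

print(out_cached == out_recomputed)
```

In a real causal model the same thing holds per layer and per head: the prefix's keys and values never depend on tokens that come later, which is why they are safe to store.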
I built a simple system to show how this works, and I was surprised by how straightforward it was. Let's start with the basic setup:
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import torch
# Define model name
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
Nothing fancy here - we're just loading a smaller version of Llama. I chose this model because it's like a compact car: it might not be as powerful as a Tesla, but it gets the job done and you can actually park it in your garage (or in this case, run it on your laptop).
The real magic happens in how we process and store information:
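def preprocess_knowledge(
    model,
    tokenizer,
    prompt: str,
) -> DynamicCache:
    """Run the knowledge prompt through the model once and return its KV cache."""
    # With device_map="auto" the embedding layer decides where inputs must live
    embed_device = model.model.embed_tokens.weight.device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        # One forward pass fills the cache with keys and values for every prompt token
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False
        )
    return outputs.past_key_values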
This function is like creating a cheat sheet before an exam - it processes the information once and keeps the important bits readily available. The neat part is that we're not doing anything particularly complex; we're just being smart about storing what we've already calculated.
And here's how we use that stored information:
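def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 3000
) -> torch.Tensor:
    """Greedily decode a reply to `input_ids`, reusing the precomputed KV cache."""
    embed_device = model.model.embed_tokens.weight.device
    origin_ids = input_ids
    input_ids = input_ids.to(embed_device)
    output_ids = input_ids.clone()
    next_token = input_ids
    # eos_token_id can be a single int or a list of ints, depending on the model config
    eos_token_id = model.config.eos_token_id
    eos_ids = set(eos_token_id) if isinstance(eos_token_id, (list, tuple)) else {eos_token_id}
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The first pass feeds the whole question; later passes feed one token at a time
            outputs = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            # Greedy decoding: always take the most likely next token
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(embed_device)
            past_key_values = outputs.past_key_values
            output_ids = torch.cat([output_ids, next_token], dim=1)
            # .item() assumes a batch size of 1
            if next_token.item() in eos_ids:
                break
    # Return only the newly generated tokens, not the question
    return output_ids[:, origin_ids.shape[-1]:]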
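One practical detail when reusing the output of `preprocess_knowledge` across several questions: a `DynamicCache` is updated in place, so after a call to `generate` it also holds the question and the answer tokens. The toy sketch below uses a made-up `ToyCache` class (not a transformers API) to show the bookkeeping: remember the cache length right after preprocessing, then trim back to it before the next question. How you trim a real cache depends on your transformers version.

```python
class ToyCache:
    """Pure-Python stand-in for a KV cache that grows in place (hypothetical)."""

    def __init__(self):
        self.entries = []

    def append(self, token):
        self.entries.append(token)

    def __len__(self):
        return len(self.entries)

    def crop(self, max_length):
        # Drop everything added after the first max_length entries
        del self.entries[max_length:]


def toy_answer(cache, question_tokens, answer_tokens):
    """Mimics generate(): the question and the reply both end up in the cache."""
    for token in question_tokens + answer_tokens:
        cache.append(token)


cache = ToyCache()
for token in ["doc", "tokens", "here"]:  # stands in for the knowledge prompt
    cache.append(token)
knowledge_len = len(cache)  # remember where the knowledge ends

toy_answer(cache, ["q1"], ["a1", "a1b"])
len_after_first = len(cache)  # the cache now also contains question 1 and its answer

cache.crop(knowledge_len)  # reset to just the knowledge before the next question
toy_answer(cache, ["q2"], ["a2"])
```

Without the trim, the second question would be answered with the first question and answer still sitting in the context.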
Now, there's a catch - and it's a pretty significant one. This approach only works well with relatively small amounts of text - about 60 pages' worth, because everything you preload has to fit in the model's context window and the cache grows with every token. Sixty pages might sound like a lot, but in the world of enterprise document processing, it's tiny. It's like having a really efficient filing system that can only handle one drawer's worth of documents.
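A rough way to reason about where a limit like this comes from is to estimate how much memory the cache needs per token. The sketch below is back-of-the-envelope arithmetic, not a measurement. The layer, head, and dimension values are assumptions for a Llama-3.2-1B-class model (check the model's `config.json`), and 500 tokens per page is a loose guess.

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value):
    """Estimate KV cache size: one key and one value vector per layer, head, and token."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens

# Assumed values, not read from the model: verify against config.json
LAYERS = 16
KV_HEADS = 8
HEAD_DIM = 64
FP16_BYTES = 2  # matches torch_dtype=torch.float16

TOKENS_PER_PAGE = 500  # loose assumption; real pages vary a lot
PAGES = 60

per_token = kv_cache_bytes(1, LAYERS, KV_HEADS, HEAD_DIM, FP16_BYTES)
total = kv_cache_bytes(PAGES * TOKENS_PER_PAGE, LAYERS, KV_HEADS, HEAD_DIM, FP16_BYTES)

print(f"{per_token // 1024} KiB per token")
print(f"{total / 1024**3:.2f} GiB for {PAGES} pages")
```

Under these assumptions the cache costs 32 KiB per token, or a little under 1 GiB for 60 pages. The cost scales linearly with the number of tokens and with model size, so larger models hit memory limits much sooner.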
But here's the interesting part: this limitation might not matter for many real-world applications. Take customer support, for instance. When someone contacts support, they usually have a specific problem about a specific product. You don't need the entire company knowledge base in memory - just the relevant parts about that product.
I've seen this pattern before: sometimes constraints force you to be smarter about how you solve problems. Instead of trying to load everything into memory, you get better at picking what's relevant. It's like the difference between memorizing an entire textbook and just knowing which chapter to reference.
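One way to act on this is to precompute a separate cache per product and keep only a few of them in memory at a time. The sketch below is an illustration under assumptions: `TopicCacheStore` and `fake_build` are hypothetical names, and `fake_build` stands in for an expensive call such as `preprocess_knowledge` on that product's documents.

```python
from collections import OrderedDict
from typing import Callable, Generic, TypeVar

C = TypeVar("C")

class TopicCacheStore(Generic[C]):
    """Keeps at most `capacity` precomputed caches, evicting the least recently used."""

    def __init__(self, build: Callable[[str], C], capacity: int = 2) -> None:
        self._build = build
        self._capacity = capacity
        self._caches: "OrderedDict[str, C]" = OrderedDict()

    def get(self, topic: str) -> C:
        if topic in self._caches:
            self._caches.move_to_end(topic)  # mark as most recently used
            return self._caches[topic]
        cache = self._build(topic)  # expensive step, done once per topic
        self._caches[topic] = cache
        if len(self._caches) > self._capacity:
            self._caches.popitem(last=False)  # drop the least recently used topic
        return cache

    def topics(self) -> list:
        return list(self._caches)


build_calls = []

def fake_build(topic: str) -> str:
    # Hypothetical stand-in for building a real KV cache from that product's docs
    build_calls.append(topic)
    return f"kv-cache-for-{topic}"


store = TopicCacheStore(fake_build, capacity=2)
store.get("espresso-machine")
store.get("drip-brewer")
store.get("espresso-machine")  # reused, no rebuild
store.get("grinder")           # over capacity: evicts "drip-brewer"
```

The capacity is the knob to tune: each stored cache costs memory, so the right number depends on how large each product's documentation is and how much memory you have.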
This approach really shines in situations where you need quick, focused responses based on a manageable amount of context. Think of it as having a really knowledgeable assistant who's great at handling one topic at a time, rather than trying to be an expert on everything simultaneously.
For small businesses or specific departments within larger organizations, this could be exactly what they need. It's not going to replace huge enterprise systems, but it might be perfect for that customer service team helping people troubleshoot their coffee makers, or that internal documentation system where people look up company policies.
The key insight here isn't just about the technical implementation - it's about recognizing that sometimes, working within constraints leads to more practical solutions. You might not be able to process your entire document database at once, but do you really need to?