nuojohnchen committed
Commit 044d0d9 · verified · 1 Parent(s): 7ede7c6

Update app.py

Files changed (1)
  1. app.py +16 -5
app.py CHANGED
@@ -154,11 +154,22 @@ def extract_text_from_pdf(pdf_bytes):
 
             # Print progress
             print(f"Processed page {page_num+1}/{len(doc)}")
+
+            # Stop early once roughly 15000 tokens (approximated by word count) have been extracted
+            if len(full_text.split()) > 15000:
+                print("Reached 15000 token limit, stopping extraction")
+                break
 
             # Clear GPU memory
             del pixel_values, outputs
             torch.cuda.empty_cache()
 
+        # Ensure the final text does not exceed ~15000 tokens (word-count approximation)
+        words = full_text.split()
+        if len(words) > 15000:
+            full_text = " ".join(words[:15000])
+            print(f"Truncated paper content to 15000 tokens")
+
         return full_text
     except Exception as e:
         import traceback
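The cap added in this hunk counts whitespace-separated words as a stand-in for tokens. A minimal, self-contained sketch of the same early-stop-plus-final-truncation behaviour (the helper name and the dummy page list are illustrative, not code from app.py):

```python
def truncate_words(text: str, max_words: int = 15000) -> str:
    """Cap text at max_words whitespace-separated words (the word-count
    approximation the diff uses in place of a true token count)."""
    words = text.split()
    if len(words) > max_words:
        print(f"Truncated paper content to {max_words} words")
        return " ".join(words[:max_words])
    return text


# Dummy per-page loop mirroring the early stop added above.
pages = ["page one text ...", "page two text ..."]  # stand-in for per-page OCR output
full_text = ""
for page_num, page_text in enumerate(pages):
    full_text += page_text + "\n"
    print(f"Processed page {page_num + 1}/{len(pages)}")
    if len(full_text.split()) > 15000:
        print("Reached the 15000-word limit, stopping extraction")
        break
full_text = truncate_words(full_text)
```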
@@ -233,12 +244,12 @@ Focus on clear, concise, and evidence-based improvements that align with the ove
         add_generation_prompt=True
     )
 
-    # Check input length and truncate to 15000 tokens before encoding
+    # Check input length and truncate to 16384 tokens before encoding
     input_tokens = tokenizer.encode(text)
-    if len(input_tokens) > 15000:  # Limit to 15k tokens
-        input_tokens = input_tokens[:15000]
+    if len(input_tokens) > 16384:  # the model's maximum context length
+        input_tokens = input_tokens[:16384]
         text = tokenizer.decode(input_tokens)
-        print(f"Input truncated to 15000 tokens")
+        print(f"Input truncated to 16384 tokens")
 
     progress(0.5, desc="Generating improved text...")
     # Generate non-streaming
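The encode-truncate-decode pattern this hunk retargets to 16384 tokens can be sketched as follows, assuming a Hugging Face transformers tokenizer; the checkpoint name is a placeholder, not the model this Space actually loads:

```python
from transformers import AutoTokenizer

MAX_INPUT_TOKENS = 16384  # context limit used in the hunk above

# Placeholder checkpoint; the Space's own tokenizer is loaded elsewhere in app.py.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def truncate_to_context(text: str, limit: int = MAX_INPUT_TOKENS) -> str:
    """Encode the prompt, cap the token ids at the context limit, and decode back to text."""
    input_tokens = tokenizer.encode(text)
    if len(input_tokens) > limit:
        input_tokens = input_tokens[:limit]
        text = tokenizer.decode(input_tokens, skip_special_tokens=True)
        print(f"Input truncated to {limit} tokens")
    return text
```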
@@ -250,7 +261,7 @@ Focus on clear, concise, and evidence-based improvements that align with the ove
     with torch.no_grad():
         output_ids = model.generate(
             input_ids,
-            attention_mask=attention_mask,  # add attention mask
+            attention_mask=attention_mask,
             max_new_tokens=max_new_tokens,
             do_sample=(temperature > 0),
             temperature=temperature if temperature > 0 else 1.0,
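The last hunk only drops the inline comment; `attention_mask` is still passed explicitly to `model.generate`. A hedged sketch of where that mask typically comes from, assuming a standard transformers causal LM (the checkpoint name is again a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder checkpoint

inputs = tokenizer("Improve this abstract:", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]  # 1 for real tokens, 0 for padding

temperature = 0.7
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,  # silences the missing-attention-mask warning
        max_new_tokens=64,
        do_sample=(temperature > 0),
        temperature=temperature if temperature > 0 else 1.0,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```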
 