@@ -4,6 +4,7 @@ use std::sync::{Arc, RwLock};
44
55use bimap:: BiMap ;
66use dashmap:: DashMap ;
7+ use rustc_hash:: FxHashSet ;
78use tower_lsp_server:: ls_types:: DocumentRangeFormattingParams ;
89use url:: Url ;
910
@@ -200,14 +201,12 @@ impl LanguageServer for Backend {
200201 & self ,
201202 params : ls_types:: DocumentFormattingParams ,
202203 ) -> jsonrpc:: Result < Option < Vec < ls_types:: TextEdit > > > {
203- if ! self
204+ if self
204205 . error_map
205206 . get ( & params. text_document . uri . to_string ( ) )
206207 . unwrap ( )
207208 . iter ( )
208- . filter ( |e| matches ! ( e, LspError :: SyntaxError ( _) ) )
209- . collect :: < Vec < _ > > ( )
210- . is_empty ( )
209+ . any ( |e| matches ! ( e, LspError :: SyntaxError ( _) ) )
211210 {
212211 return Ok ( None ) ;
213212 }
@@ -360,15 +359,40 @@ impl Backend {
360359 . collect :: < Vec < _ > > ( ) ;
361360
362361 if errors. is_empty ( ) && self . config . enable_type_checking {
362+ let hir_guard = self . hir . read ( ) . unwrap ( ) ;
363363 let mut checker = mq_check:: TypeChecker :: with_options ( self . config . type_checker_options ) ;
364- let type_errors = checker. check ( & self . hir . read ( ) . unwrap ( ) ) ;
364+ let type_errors = checker. check ( & hir_guard) ;
365+
366+ // Build a set of (line, column) start positions from the current source's symbols
367+ // so that type errors originating from other sources (e.g., pre-loaded modules)
368+ // are not incorrectly attributed to this file.
369+ let source_locations: FxHashSet < ( u32 , usize ) > = hir_guard
370+ . symbols_for_source ( source_id)
371+ . filter_map ( |( _, symbol) | {
372+ symbol
373+ . source
374+ . text_range
375+ . as_ref ( )
376+ . map ( |r| ( r. start . line , r. start . column ) )
377+ } )
378+ . collect ( ) ;
379+
365380 self . type_env_map
366381 . insert ( uri_string. clone ( ) , checker. symbol_types ( ) . clone ( ) ) ;
367- errors. extend ( type_errors. into_iter ( ) . map ( LspError :: TypeError ) ) ;
382+ errors. extend (
383+ type_errors
384+ . into_iter ( )
385+ . filter ( |e| {
386+ e. location ( )
387+ . map ( |( line, col) | source_locations. contains ( & ( line, col) ) )
388+ . unwrap_or ( false )
389+ } )
390+ . map ( LspError :: TypeError ) ,
391+ ) ;
368392 }
369393
370394 self . source_map . write ( ) . unwrap ( ) . insert ( uri_string. clone ( ) , source_id) ;
371- self . text_map . insert ( uri_string. clone ( ) , text. to_string ( ) . into ( ) ) ;
395+ self . text_map . insert ( uri_string. clone ( ) , text. into ( ) ) ;
372396 self . error_map . insert ( uri_string, errors) ;
373397 }
374398
@@ -381,28 +405,29 @@ impl Backend {
381405
382406 // Add parsing errors if they exist
383407 if let Some ( errors) = file_errors {
384- let errors: Vec < ls_types:: Diagnostic > = ( * errors) . iter ( ) . map ( Into :: into) . collect :: < Vec < _ > > ( ) ;
385- diagnostics. extend ( errors) ;
408+ diagnostics. extend ( errors. iter ( ) . map ( Into :: into) ) ;
386409 }
387410
388411 {
389412 let source_map_guard = self . source_map . read ( ) . unwrap ( ) ;
390413 if let Some ( source_id) = source_map_guard. get_by_left ( & uri_string) {
391414 let hir_guard = self . hir . read ( ) . unwrap ( ) ;
392415
393- // Build a map of text_range -> bool for this file's symbols for O(1) lookup
394- let mut range_map = std:: collections:: HashMap :: new ( ) ;
395- for ( _, symbol) in hir_guard. symbols ( ) {
396- if symbol. source . source_id == Some ( * source_id)
397- && let Some ( ref text_range) = symbol. source . text_range
398- {
399- range_map. insert ( text_range, true ) ;
400- }
401- }
416+ // Build a set of text_ranges for this file's symbols for O(1) lookup
417+ let range_set: FxHashSet < mq_lang:: Range > = hir_guard
418+ . symbols ( )
419+ . filter_map ( |( _, symbol) | {
420+ if symbol. source . source_id == Some ( * source_id) {
421+ symbol. source . text_range
422+ } else {
423+ None
424+ }
425+ } )
426+ . collect ( ) ;
402427
403428 // Filter HIR errors to only include ones from this specific source
404429 diagnostics. extend ( hir_guard. error_ranges ( ) . into_iter ( ) . filter_map ( |( message, item) | {
405- if range_map . contains_key ( & item) {
430+ if range_set . contains ( & item) {
406431 Some ( ls_types:: Diagnostic :: new_simple (
407432 ls_types:: Range :: new (
408433 ls_types:: Position {
@@ -448,7 +473,7 @@ impl Backend {
448473
449474 // Add HIR warnings (including unreachable code warnings)
450475 diagnostics. extend ( hir_guard. warning_ranges ( ) . into_iter ( ) . filter_map ( |( message, item) | {
451- if range_map . contains_key ( & item) {
476+ if range_set . contains ( & item) {
452477 let mut diagnostic = ls_types:: Diagnostic :: new_simple (
453478 ls_types:: Range :: new (
454479 ls_types:: Position {
0 commit comments